[rocm-libraries] ROCm/rocm-libraries#5939 (commit 6fb1791)

[CK_TILE] Flatten nested static_for loops into static_ford (#5939)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Summary
Mechanical conversion of 129 nested `static_for`/`static_ford` patterns
to flat `static_ford` across 29 ck_tile header files.

Each conversion eliminates intermediate lambda closure instantiations by
replacing nested compile-time loops with a single flat iteration using
index decomposition.

### What `static_ford` eliminates

When `static_for` loops are nested, each level creates unique closure
types:
```cpp
// BEFORE: M + M×N = 20 IR functions (for M=4, N=4)
static_for<0, 4, 1>{}([&](auto m) {        // 4 closure instantiations
    static_for<0, 4, 1>{}([&](auto n) {     // 4×4 = 16 closure instantiations
        body(m, n);
    });
});

// AFTER: M×N = 16 IR functions (with ford_applier, no intermediates)
static_ford<sequence<4, 4>>{}([&](auto mn) {
    constexpr auto m = number<mn[number<0>{}]>{};
    constexpr auto n = number<mn[number<1>{}]>{};
    body(m, n);
});
```

### Pattern categories converted

| Category | Count | Description |
|----------|-------|-------------|
| C (2-level `static_for` chains) | 112 | Nested `static_for` → `static_ford` |
| C3 (3-level `static_for` chains) | 9 | Three consecutive nests → `static_ford` |
| Partial rescue | 3 | Outer 2 levels of blocked 4-level nests |
| B (nested `static_ford` merge) | 5 | Two nested `static_ford` → single higher-dim `static_ford` |
| **Total** | **129** | Across 29 files |

6 false positives were detected and reverted (in `tensor_adaptor.hpp`,
`tile_distribution.hpp`, `tile_distribution_encoding.hpp`) where the
inner loop bound depended on the outer variable.

### Files changed by family

| Family | Files | Sites |
|--------|-------|-------|
| Block GEMM | 12 | ~20 |
| FlatMM pipelines | 4 | ~69 (including 5 ford-ford merges) |
| GEMM quant | 7 | ~22 |
| FlatMM kernel | 1 | 2 |
| FMHA | 1 | 2 |
| Reduce/norm | 2 | 2 |
| Epilogue | 1 | 1 |

### Blocked locations from review comments

- **block_gemm_areg_breg_creg_v1.hpp:356** — BLOCKED: runtime scale
loads (`scale_a_slice`, `scale_b_slice`, A warp tensor load) between
every nesting level
- **block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp:228** — BLOCKED:
`zero_accumulators()` before inner loop; `sched_barrier` + conditional
`block_sync_lds()` after inner loop
- **block_universal_gemm_as_aquant_bs_bquant_cr.hpp:298** — BLOCKED:
runtime `CWarpTensor` construction before inner loop; quantization scale
application code after inner loop
- **block_universal_gemm_as_aquant_bs_cr.hpp:277** — BLOCKED: same
pattern as above
- **block_universal_gemm_as_bs_bquant_cr.hpp:367** — BLOCKED: same
pattern as above

## Depends on
- #5938 ([CK_TILE] Optimize static_ford and sequence compile-time
infrastructure) — provides the `ford_applier` that makes these
conversions beneficial. Without it, `static_ford` uses a recursive
implementation that provides no IR function savings.

## Results (combined with #5938)

### Build Time (Wilcoxon signed-rank, 7 paired trials, gfx942)

| Target | Base (s) | Treat (s) | Delta | % | Significant? |
|--------|----------|-----------|-------|---|-------------|
| **flatmm** | 161.1 | 149.0 | **-12.1s** | **-7.5%** | **YES** (p<0.01, 7/7 wins) |
| **universal_gemm** | 225.4 | 220.3 | **-5.1s** | **-2.3%** | **YES** (p<0.01, 7/7 wins) |

### IR Function Counts (device trace, gfx942)

| Target | InstFunc | CodeGen |
|--------|----------|---------|
| universal_gemm | **-8.5%** | **-9.2%** |
| flatmm | **-7.6%** | **-10.5%** |

### ASM Equivalence
5/5 PASS — 650,151 lines verified identical (gfx942). TUs:
universal_gemm, flatmm_basic, fmha_bwd, reduce, bscale.

## Test plan
- [x] ASM equivalence verified (650K lines, gfx942)
- [x] Wilcoxon timing verified (7 trials, p<0.01)
- [x] IR function counts verified (-7.6% to -8.5% InstFunc reduction, -9.2% to -10.5% CodeGen reduction)
- [ ] CI

🤖 Generated with [Claude Code](https://claude.com/claude-code)
This commit is contained in:
Christopher Millette
2026-04-07 14:38:07 +00:00
committed by assistant-librarian[bot]
parent c2ac7aa7b0
commit a170e2bd9d
29 changed files with 2160 additions and 2219 deletions

View File

@@ -187,11 +187,11 @@ struct EpilogueGraph
Context& context) const
{
// For each iteration, process all epilogues in order
static_for<0, Steps, 1>{}([&](auto iAccess) {
static_for<0, sizeof...(EpilogueTypes), 1>{}([&](auto I) {
epilogues.template get<I.value>()(
out_window, acc_tile, aux_windows, p_smem, context, iAccess);
});
static_ford<sequence<Steps, sizeof...(EpilogueTypes)>>{}([&](auto iI) {
constexpr auto iAccess = number<iI[number<0>{}]>{};
constexpr auto I = number<iI[number<1>{}]>{};
epilogues.template get<I.value>()(
out_window, acc_tile, aux_windows, p_smem, context, iAccess);
});
}
};

View File

@@ -92,29 +92,29 @@ struct BlockFlatmmASmemBSmemCRegV1
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block window
const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
constexpr auto kIter = number<km[number<0>{}]>{};
constexpr auto mIter = number<km[number<1>{}]>{};
// read A warp tensor from A block window
const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor(nIter)(kIter));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor(nIter)(kIter));
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
__builtin_amdgcn_sched_barrier(0x7F6);
});
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
__builtin_amdgcn_sched_barrier(0x7F6);
});
});
}

View File

@@ -1105,15 +1105,14 @@ struct MoeFlatmmKernel
statically_indexed_array<index_t, ScaleMRepeat> scale_m_offsets;
if constexpr(!BMXFP4_Pipeline)
static_for<0, MRepeat, 1>{}([&](auto mIter) {
static_for<0, kM0, 1>{}([&](auto m0) {
static_for<0, kM2, 1>{}([&](auto m2) {
const auto row_idx =
coord_m + mIter * MPerXdl + m0 * kM1 * kM2 + m2 + scale_m_coord[I0];
scale_m_offsets[mIter * number<kM0 * kM2>{} + m0 * number<kM2>{} + m2] =
row_to_token_idx(row_idx);
});
});
static_ford<sequence<MRepeat, kM0, kM2>>{}([&](auto mmm) {
constexpr auto mIter = number<mmm[number<0>{}]>{};
constexpr auto m0 = number<mmm[number<1>{}]>{};
constexpr auto m2 = number<mmm[number<2>{}]>{};
const auto row_idx =
coord_m + mIter * MPerXdl + m0 * kM1 * kM2 + m2 + scale_m_coord[I0];
scale_m_offsets[mIter * number<kM0 * kM2>{} + m0 * number<kM2>{} + m2] =
row_to_token_idx(row_idx);
});
constexpr int DynamicTileOffsetFlag = 0;
@@ -1426,19 +1425,19 @@ struct MoeFlatmmKernel
statically_indexed_array<statically_indexed_array<bool, MPerThread>, NumMEpiTile>
c_scatter_valids;
auto c_coord = dram_tile_distribution.calculate_index();
static_for<0, NumMEpiTile, 1>{}([&](auto mIter) {
static_for<0, MPerThread, 1>{}([&](auto m0) {
auto row_idx = coord_m + mIter * MPerIterationShuffle + c_coord[0] + m0;
auto fused_token =
kargs.p_sorted_token_ids[row_idx]; // topk-idx[31:24] + token_idx[23:0]
static_ford<sequence<NumMEpiTile, MPerThread>>{}([&](auto mm) {
constexpr auto mIter = number<mm[number<0>{}]>{};
constexpr auto m0 = number<mm[number<1>{}]>{};
auto row_idx = coord_m + mIter * MPerIterationShuffle + c_coord[0] + m0;
auto fused_token =
kargs.p_sorted_token_ids[row_idx]; // topk-idx[31:24] + token_idx[23:0]
index_t scatter_token_id = fused_token & token_id_mask;
c_scatter_valids[mIter][m0] = (scatter_token_id < kargs.NumTokens);
if constexpr(IsInputGemm)
scatter_token_id =
scatter_token_id * kargs.TopK + (fused_token >> token_id_offset);
c_scatter_offsets[mIter][m0] = scatter_token_id * kargs.stride_C;
});
index_t scatter_token_id = fused_token & token_id_mask;
c_scatter_valids[mIter][m0] = (scatter_token_id < kargs.NumTokens);
if constexpr(IsInputGemm)
scatter_token_id =
scatter_token_id * kargs.TopK + (fused_token >> token_id_offset);
c_scatter_offsets[mIter][m0] = scatter_token_id * kargs.stride_C;
});
//===----------------------------------------------------------------------===//

View File

@@ -606,16 +606,16 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1
MIterPerWarp>
a_warp_windows_pong;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
static_ford<sequence<MIterPerWarp, KIterPerWarp>>{}([&](auto mk) {
constexpr auto mIter = number<mk[number<0>{}]>{};
constexpr auto kIter = number<mk[number<1>{}]>{};
a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
move_tile_window(a_warp_windows_ping(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
move_tile_window(a_warp_windows_pong(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
move_tile_window(a_warp_windows_ping(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
move_tile_window(a_warp_windows_pong(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
// Block GEMM
@@ -656,15 +656,15 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// prefetch B
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
static_ford<sequence<NIterPerWarp, KIterPerWarp>>{}([&](auto nk) {
constexpr auto nIter = number<nk[number<0>{}]>{};
constexpr auto kIter = number<nk[number<1>{}]>{};
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
@@ -701,15 +701,15 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1
while(iCounter > 0)
{
// prefetch B(2i+1)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
constexpr auto kIter = number<kn[number<0>{}]>{};
constexpr auto nIter = number<kn[number<1>{}]>{};
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
// Prefill A(2i+1)
@@ -722,44 +722,44 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// GEMM 2i
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
constexpr auto kIter = number<km[number<0>{}]>{};
constexpr auto mIter = number<km[number<1>{}]>{};
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter)(kIter));
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter)(kIter));
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
}
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
}
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
});
// move B window to next flat K
@@ -776,15 +776,15 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1
// Next K
// prefetch B(2i+2)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
constexpr auto kIter = number<kn[number<0>{}]>{};
constexpr auto nIter = number<kn[number<1>{}]>{};
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
// Prefill A(2i+2)
@@ -797,43 +797,43 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// GEMM 2i+1
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
constexpr auto kIter = number<km[number<0>{}]>{};
constexpr auto mIter = number<km[number<1>{}]>{};
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_pong(nIter)(kIter));
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_pong(nIter)(kIter));
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
}
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
}
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
});
// move B window to next flat K
@@ -854,15 +854,15 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1
if constexpr(TailNum == TailNumber::Even)
{
// prefetch B(loopK)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
constexpr auto kIter = number<kn[number<0>{}]>{};
constexpr auto nIter = number<kn[number<1>{}]>{};
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
// Prefill A(loopK)
@@ -870,44 +870,44 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1
store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
// GEMM loopK-1
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
constexpr auto kIter = number<km[number<0>{}]>{};
constexpr auto mIter = number<km[number<1>{}]>{};
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter)(kIter));
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter)(kIter));
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
}
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
}
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
});
static_for<0, m_preload, 1>{}([&](auto loadIter) {
@@ -920,86 +920,86 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1
Last2ndHotLoopScheduler();
// GEMM loopK
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
constexpr auto kIter = number<km[number<0>{}]>{};
constexpr auto mIter = number<km[number<1>{}]>{};
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_pong(nIter)(kIter));
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_pong(nIter)(kIter));
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
}
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
}
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
});
LastHotLoopScheduler();
}
else if constexpr(TailNum == TailNumber::Odd)
{
// GEMM loopK
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
constexpr auto kIter = number<km[number<0>{}]>{};
constexpr auto mIter = number<km[number<1>{}]>{};
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter)(kIter));
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter)(kIter));
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
}
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
}
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
});
LastHotLoopScheduler();
}

View File

@@ -529,22 +529,22 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
MIterPerWarp>
a_warp_windows_pong;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
static_ford<sequence<MIterPerWarp, KIterPerWarp>>{}([&](auto mk) {
constexpr auto mIter = number<mk[number<0>{}]>{};
constexpr auto kIter = number<mk[number<1>{}]>{};
a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
move_tile_window(a_warp_windows_ping(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
move_tile_window(a_warp_windows_ping(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
static_ford<sequence<MIterPerWarp, KIterPerWarp>>{}([&](auto mk) {
constexpr auto mIter = number<mk[number<0>{}]>{};
constexpr auto kIter = number<mk[number<1>{}]>{};
a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
move_tile_window(a_warp_windows_pong(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
move_tile_window(a_warp_windows_pong(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
// Block GEMM
@@ -592,26 +592,26 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
2;
// prefetch B
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
static_ford<sequence<NIterPerWarp, KIterPerWarp>>{}([&](auto nk) {
constexpr auto nIter = number<nk[number<0>{}]>{};
constexpr auto kIter = number<nk[number<1>{}]>{};
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
if constexpr(!IsGateUpMode)
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
if constexpr(!IsGateUpMode)
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
else
{
if constexpr(nIter % 2 == 0)
move_tile_window(
b_flat_dram_windows(nIter)(kIter),
{nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
else
{
if constexpr(nIter % 2 == 0)
move_tile_window(
b_flat_dram_windows(nIter)(kIter),
{nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
else
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
kIter * KFlatPerBlockPerIter});
}
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
kIter * KFlatPerBlockPerIter});
}
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
@@ -648,28 +648,27 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
while(iCounter > 0)
{
// prefetch B(2i+1)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
constexpr auto kIter = number<kn[number<0>{}]>{};
constexpr auto nIter = number<kn[number<1>{}]>{};
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
if constexpr(!IsGateUpMode)
if constexpr(!IsGateUpMode)
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
else
{
if constexpr(nIter % 2 == 0)
move_tile_window(
b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
{nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
else
{
if constexpr(nIter % 2 == 0)
move_tile_window(
b_flat_dram_windows(nIter)(kIter),
{nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
else
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
kIter * KFlatPerBlockPerIter});
}
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
kIter * KFlatPerBlockPerIter});
}
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
// Prefill A(2i+1)
@@ -682,44 +681,44 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// GEMM 2i
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
constexpr auto kIter = number<km[number<0>{}]>{};
constexpr auto mIter = number<km[number<1>{}]>{};
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter)(kIter));
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter)(kIter));
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
}
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
}
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
});
// move B window to next flat K
@@ -736,28 +735,27 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
// Next K
// prefetch B(2i+2)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
constexpr auto kIter = number<kn[number<0>{}]>{};
constexpr auto nIter = number<kn[number<1>{}]>{};
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
if constexpr(!IsGateUpMode)
if constexpr(!IsGateUpMode)
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
else
{
if constexpr(nIter % 2 == 0)
move_tile_window(
b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
{nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
else
{
if constexpr(nIter % 2 == 0)
move_tile_window(
b_flat_dram_windows(nIter)(kIter),
{nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
else
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
kIter * KFlatPerBlockPerIter});
}
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
kIter * KFlatPerBlockPerIter});
}
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
// Prefill A(2i+2)
@@ -770,43 +768,43 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// GEMM 2i+1
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
constexpr auto kIter = number<km[number<0>{}]>{};
constexpr auto mIter = number<km[number<1>{}]>{};
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_pong(nIter)(kIter));
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_pong(nIter)(kIter));
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
}
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
}
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
});
// move B window to next flat K
@@ -827,28 +825,27 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
if constexpr(TailNum == TailNumber::Even)
{
// prefetch B(loopK)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
constexpr auto kIter = number<kn[number<0>{}]>{};
constexpr auto nIter = number<kn[number<1>{}]>{};
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
if constexpr(!IsGateUpMode)
if constexpr(!IsGateUpMode)
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
else
{
if constexpr(nIter % 2 == 0)
move_tile_window(
b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
{nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
else
{
if constexpr(nIter % 2 == 0)
move_tile_window(
b_flat_dram_windows(nIter)(kIter),
{nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
else
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
kIter * KFlatPerBlockPerIter});
}
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
kIter * KFlatPerBlockPerIter});
}
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
// Prefill A(loopK)
@@ -856,44 +853,44 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
// GEMM loopK-1
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
constexpr auto kIter = number<km[number<0>{}]>{};
constexpr auto mIter = number<km[number<1>{}]>{};
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter)(kIter));
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter)(kIter));
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
}
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
}
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
});
static_for<0, m_preload, 1>{}([&](auto loadIter) {
@@ -906,86 +903,86 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
Last2ndHotLoopScheduler();
// GEMM loopK
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
constexpr auto kIter = number<km[number<0>{}]>{};
constexpr auto mIter = number<km[number<1>{}]>{};
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_pong(nIter)(kIter));
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_pong(nIter)(kIter));
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
}
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
}
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
});
LastHotLoopScheduler();
}
else if constexpr(TailNum == TailNumber::Odd)
{
// GEMM loopK
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
constexpr auto kIter = number<km[number<0>{}]>{};
constexpr auto mIter = number<km[number<1>{}]>{};
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter)(kIter));
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter)(kIter));
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
}
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
}
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
});
LastHotLoopScheduler();
}

View File

@@ -486,13 +486,13 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
auto c_block_tile = BlockFlatmm{}.MakeCBlockTile();
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensors(mIter)(nIter).get_thread_buffer());
});
// Single flat compile-time loop over every (mIter, nIter) pair; this is the
// flattened form of a nested static_for over MIterPerWarp and NIterPerWarp.
static_ford<sequence<MIterPerWarp, NIterPerWarp>>{}([&](auto mn) {
// `mn` is the multi-index of this iteration: element 0 is the M index,
// element 1 is the N index. Both are re-wrapped as compile-time numbers so
// they can be used as template arguments below.
constexpr auto mIter = number<mn[number<0>{}]>{};
constexpr auto nIter = number<mn[number<1>{}]>{};
// Write the thread buffer of c_warp_tensors(mIter)(nIter) into c_block_tile,
// at origin (mIter, nIter, 0...) with lengths (1, 1, c_warp_y_lengths...).
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensors(mIter)(nIter).get_thread_buffer());
});
return c_block_tile;
}
@@ -643,24 +643,23 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
});
// prefetch Scale A
static_for<0, MPackIterPerWarp, 1>{}([&](auto impack) {
static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) {
scale_a_tile_tensor_ping(impack)(ikpack) = load_tile_with_offset(
scale_a_dram_window,
static_ford<sequence<MPackIterPerWarp, KPackIterPerWarp>>{}([&](auto ii) {
constexpr auto impack = number<ii[number<0>{}]>{};
constexpr auto ikpack = number<ii[number<1>{}]>{};
scale_a_tile_tensor_ping(impack)(ikpack) =
load_tile_with_offset(scale_a_dram_window,
impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k);
});
impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k);
});
// move Scale A window to next K
move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
// prefetch Scale B
static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) {
static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) {
scale_b_tile_tensor_ping(inpack)(ikpack) = load_tile_with_offset(
scale_b_dram_window,
inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k);
});
// Flat compile-time loop over (inpack, ikpack); element 0 of the multi-index
// `ii` is the N-pack index, element 1 is the K-pack index.
static_ford<sequence<NPackIterPerWarp, KPackIterPerWarp>>{}([&](auto ii) {
constexpr auto inpack = number<ii[number<0>{}]>{};
constexpr auto ikpack = number<ii[number<1>{}]>{};
// Load from scale_b_dram_window at an offset built from the per-N and per-K
// steps, and store the result in scale_b_tile_tensor_ping(inpack)(ikpack).
scale_b_tile_tensor_ping(inpack)(ikpack) = load_tile_with_offset(
scale_b_dram_window, inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k);
});
// move Scale B window to next K
move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
@@ -698,34 +697,34 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
// MAIN LOOP
auto main_body_implx2 = [&]() mutable {
// prefetch B(2i+1)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset(
b_flat_dram_window,
b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter);
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
constexpr auto kIter = number<kn[number<0>{}]>{};
constexpr auto nIter = number<kn[number<1>{}]>{};
b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset(
b_flat_dram_window,
b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter);
// move B window to next flat K
if constexpr(kIter == KIterPerWarp - 1)
b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset(
tuple<number<0>, number<KIterPerWarp * KFlatBytesPerBlockPerIter>>{});
});
// move B window to next flat K
if constexpr(kIter == KIterPerWarp - 1)
b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset(
tuple<number<0>, number<KIterPerWarp * KFlatBytesPerBlockPerIter>>{});
});
// prefetch Scale A and Scale B (2i+1)
static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) {
static_for<0, MPackIterPerWarp, 1>{}([&](auto impack) {
scale_a_tile_tensor_pong(impack)(ikpack) = load_tile_with_offset(
scale_a_dram_window,
impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k);
});
// Flat compile-time loop over (ikpack, impack). Note the dimension order:
// the K-pack index is element 0 of `ii` and the M-pack index is element 1,
// matching the nested static_for it replaces (K outer, M inner).
static_ford<sequence<KPackIterPerWarp, MPackIterPerWarp>>{}([&](auto ii) {
constexpr auto ikpack = number<ii[number<0>{}]>{};
constexpr auto impack = number<ii[number<1>{}]>{};
// Load from scale_a_dram_window at offset
// impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k and store it
// in scale_a_tile_tensor_pong(impack)(ikpack).
scale_a_tile_tensor_pong(impack)(ikpack) = load_tile_with_offset(
scale_a_dram_window,
impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k);
});
static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) {
static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) {
scale_b_tile_tensor_pong(inpack)(ikpack) = load_tile_with_offset(
scale_b_dram_window,
inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k);
});
// Flat compile-time loop over (ikpack, inpack). The K-pack index is element 0
// of `ii` and the N-pack index is element 1 (K outer, N inner in the nested
// static_for this replaces).
static_ford<sequence<KPackIterPerWarp, NPackIterPerWarp>>{}([&](auto ii) {
constexpr auto ikpack = number<ii[number<0>{}]>{};
constexpr auto inpack = number<ii[number<1>{}]>{};
// Load from scale_b_dram_window at offset
// inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k and store it
// in scale_b_tile_tensor_pong(inpack)(ikpack).
scale_b_tile_tensor_pong(inpack)(ikpack) = load_tile_with_offset(
scale_b_dram_window,
inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k);
});
// GEMM 2i
@@ -788,34 +787,34 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
////////////////////////////// Next K //////////////////////////////
// prefetch B(2i+2)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset(
b_flat_dram_window,
b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter);
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
constexpr auto kIter = number<kn[number<0>{}]>{};
constexpr auto nIter = number<kn[number<1>{}]>{};
b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset(
b_flat_dram_window,
b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter);
// move B window to next flat K
if constexpr(kIter == KIterPerWarp - 1)
b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset(
tuple<number<0>, number<KIterPerWarp * KFlatBytesPerBlockPerIter>>{});
});
// move B window to next flat K
if constexpr(kIter == KIterPerWarp - 1)
b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset(
tuple<number<0>, number<KIterPerWarp * KFlatBytesPerBlockPerIter>>{});
});
// prefetch Scale A and Scale B (2i+2)
static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) {
static_for<0, MPackIterPerWarp, 1>{}([&](auto impack) {
scale_a_tile_tensor_ping(impack)(ikpack) = load_tile_with_offset(
scale_a_dram_window,
impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k);
});
// Flat compile-time loop over (ikpack, impack). The K-pack index is element 0
// of `ii` and the M-pack index is element 1 (K outer, M inner in the nested
// static_for this replaces).
static_ford<sequence<KPackIterPerWarp, MPackIterPerWarp>>{}([&](auto ii) {
constexpr auto ikpack = number<ii[number<0>{}]>{};
constexpr auto impack = number<ii[number<1>{}]>{};
// Load from scale_a_dram_window at offset
// impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k and store it
// in scale_a_tile_tensor_ping(impack)(ikpack).
scale_a_tile_tensor_ping(impack)(ikpack) = load_tile_with_offset(
scale_a_dram_window,
impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k);
});
static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) {
static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) {
scale_b_tile_tensor_ping(inpack)(ikpack) = load_tile_with_offset(
scale_b_dram_window,
inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k);
});
// Flat compile-time loop over (ikpack, inpack). The K-pack index is element 0
// of `ii` and the N-pack index is element 1 (K outer, N inner in the nested
// static_for this replaces).
static_ford<sequence<KPackIterPerWarp, NPackIterPerWarp>>{}([&](auto ii) {
constexpr auto ikpack = number<ii[number<0>{}]>{};
constexpr auto inpack = number<ii[number<1>{}]>{};
// Load from scale_b_dram_window at offset
// inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k and store it
// in scale_b_tile_tensor_ping(inpack)(ikpack).
scale_b_tile_tensor_ping(inpack)(ikpack) = load_tile_with_offset(
scale_b_dram_window,
inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k);
});
// GEMM 2i+1
@@ -888,28 +887,28 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
if constexpr(TailNum == TailNumber::Even)
{
// prefetch B(loopK)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset(
b_flat_dram_window,
b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter);
});
// Flat compile-time loop over (kIter, nIter). The K index is element 0 of
// `kn` and the N index is element 1 (K outer, N inner in the nested
// static_for this replaces).
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
constexpr auto kIter = number<kn[number<0>{}]>{};
constexpr auto nIter = number<kn[number<1>{}]>{};
// Load from b_flat_dram_window at the per-N base offset plus
// kIter * KFlatBytesPerBlockPerIter, and store the result in
// b_warp_tensor_pong(nIter)(kIter). Unlike the main-loop variant of this
// prefetch, b_flat_dram_offsets is only read here, not advanced.
b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset(
b_flat_dram_window,
b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter);
});
// prefetch Scale A and Scale B (2i+1)
static_for<0, MPackIterPerWarp, 1>{}([&](auto impack) {
static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) {
scale_a_tile_tensor_pong(impack)(ikpack) = load_tile_with_offset(
scale_a_dram_window,
impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k);
});
// Flat compile-time loop over (impack, ikpack). Here the M-pack index is
// element 0 of `ii` and the K-pack index is element 1 (M outer, K inner),
// which mirrors the nested static_for order this replaces.
static_ford<sequence<MPackIterPerWarp, KPackIterPerWarp>>{}([&](auto ii) {
constexpr auto impack = number<ii[number<0>{}]>{};
constexpr auto ikpack = number<ii[number<1>{}]>{};
// Load from scale_a_dram_window at offset
// impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k and store it
// in scale_a_tile_tensor_pong(impack)(ikpack).
scale_a_tile_tensor_pong(impack)(ikpack) = load_tile_with_offset(
scale_a_dram_window,
impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k);
});
static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) {
static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) {
scale_b_tile_tensor_pong(inpack)(ikpack) = load_tile_with_offset(
scale_b_dram_window,
inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k);
});
// Flat compile-time loop over (inpack, ikpack). Here the N-pack index is
// element 0 of `ii` and the K-pack index is element 1 (N outer, K inner),
// which mirrors the nested static_for order this replaces.
static_ford<sequence<NPackIterPerWarp, KPackIterPerWarp>>{}([&](auto ii) {
constexpr auto inpack = number<ii[number<0>{}]>{};
constexpr auto ikpack = number<ii[number<1>{}]>{};
// Load from scale_b_dram_window at offset
// inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k and store it
// in scale_b_tile_tensor_pong(inpack)(ikpack).
scale_b_tile_tensor_pong(inpack)(ikpack) = load_tile_with_offset(
scale_b_dram_window,
inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k);
});
// GEMM loopK-1

View File

@@ -1706,22 +1706,22 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
p_warp_tensor.get_thread_buffer() = p_in.get_y_sliced_thread_data(
merge_sequences(sequence<kIter, mIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
constexpr auto kIter = number<km[number<0>{}]>{};
constexpr auto mIter = number<km[number<1>{}]>{};
p_warp_tensor.get_thread_buffer() = p_in.get_y_sliced_thread_data(
merge_sequences(sequence<kIter, mIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
#if defined(__gfx11__)
PermuteWarpGemmCToA(pt_warp_tensor, p_warp_tensor);
PermuteWarpGemmCToA(pt_warp_tensor, p_warp_tensor);
#else
pt_warp_tensor.get_thread_buffer() = p_warp_tensor.get_thread_buffer();
pt_warp_tensor.get_thread_buffer() = p_warp_tensor.get_thread_buffer();
#endif
pt_out.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths),
pt_warp_tensor.get_thread_buffer());
});
pt_out.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths),
pt_warp_tensor.get_thread_buffer());
});
}
else
@@ -1763,22 +1763,22 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
ds_warp_tensor.get_thread_buffer() = ds_in.get_y_sliced_thread_data(
merge_sequences(sequence<kIter, mIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
constexpr auto kIter = number<km[number<0>{}]>{};
constexpr auto mIter = number<km[number<1>{}]>{};
ds_warp_tensor.get_thread_buffer() = ds_in.get_y_sliced_thread_data(
merge_sequences(sequence<kIter, mIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
#if defined(__gfx11__)
PermuteWarpGemmCToA(dst_warp_tensor, ds_warp_tensor);
PermuteWarpGemmCToA(dst_warp_tensor, ds_warp_tensor);
#else
dst_warp_tensor.get_thread_buffer() = ds_warp_tensor.get_thread_buffer();
dst_warp_tensor.get_thread_buffer() = ds_warp_tensor.get_thread_buffer();
#endif
dst_out.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths),
dst_warp_tensor.get_thread_buffer());
});
dst_out.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths),
dst_warp_tensor.get_thread_buffer());
});
}
else

View File

@@ -213,38 +213,38 @@ struct BlockGemmARegBRegCRegV1
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A Block window
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
constexpr auto kIter = number<km[number<0>{}]>{};
constexpr auto mIter = number<km[number<1>{}]>{};
// read A warp tensor from A Block window
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// read C warp tensor from C block tensor
using c_iter_idx = std::
conditional_t<TransposeC, sequence<nIter, mIter>, sequence<mIter, nIter>>;
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// read C warp tensor from C block tensor
using c_iter_idx =
std::conditional_t<TransposeC, sequence<nIter, mIter>, sequence<mIter, nIter>>;
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// warp GEMM
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
}
@@ -323,73 +323,69 @@ struct BlockGemmARegBRegCRegV1
// hot loop with MX scaling and pre-packed int32_t scales:
// Outer loops iterate over pack groups (scale tile indices)
static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) {
static_for<0, MPackIterPerWarp, 1>{}([&](auto impack) {
// Get pre-packed int32_t A scale (already contains MXdlPack*KXdlPack e8m0_t)
auto scale_a_slice = scale_a_tensor.get_y_sliced_thread_data(
sequence<ikpack, impack, 0>{}, sequence<1, 1, 1>{});
const int32_t a_scale_packed = bit_cast<int32_t>(scale_a_slice[number<0>{}]);
static_ford<sequence<KPackIterPerWarp, MPackIterPerWarp>>{}([&](auto ii) {
constexpr auto ikpack = number<ii[number<0>{}]>{};
constexpr auto impack = number<ii[number<1>{}]>{};
// Get pre-packed int32_t A scale (already contains MXdlPack*KXdlPack e8m0_t)
auto scale_a_slice = scale_a_tensor.get_y_sliced_thread_data(
sequence<ikpack, impack, 0>{}, sequence<1, 1, 1>{});
const int32_t a_scale_packed = bit_cast<int32_t>(scale_a_slice[number<0>{}]);
static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) {
// Get pre-packed int32_t B scale
auto scale_b_slice = scale_b_tensor.get_y_sliced_thread_data(
sequence<ikpack, inpack, 0>{}, sequence<1, 1, 1>{});
const int32_t b_scale_packed = bit_cast<int32_t>(scale_b_slice[number<0>{}]);
static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) {
// Get pre-packed int32_t B scale
auto scale_b_slice = scale_b_tensor.get_y_sliced_thread_data(
sequence<ikpack, inpack, 0>{}, sequence<1, 1, 1>{});
const int32_t b_scale_packed = bit_cast<int32_t>(scale_b_slice[number<0>{}]);
// Inner loops: issue MFMAs within the pack group using OpSel
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
constexpr auto kIter = ikpack * KXdlPack + ikxdl;
constexpr auto mIter = impack * MXdlPack + imxdl;
// Inner loops: issue MFMAs within the pack group using OpSel
static_ford<sequence<KXdlPack, MXdlPack>>{}([&](auto jj) {
constexpr auto ikxdl = number<jj[number<0>{}]>{};
constexpr auto imxdl = number<jj[number<1>{}]>{};
constexpr auto kIter = ikpack * KXdlPack + ikxdl;
constexpr auto mIter = impack * MXdlPack + imxdl;
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() =
a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
// OpSel for A: selects byte within packed int32_t
constexpr index_t kOpSelA = ikxdl * MXdlPack + imxdl;
// OpSel for A: selects byte within packed int32_t
constexpr index_t kOpSelA = ikxdl * MXdlPack + imxdl;
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
constexpr auto nIter = inpack * NXdlPack + inxdl;
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
constexpr auto nIter = inpack * NXdlPack + inxdl;
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() =
b_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{},
b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// OpSel for B: selects byte within packed int32_t
constexpr index_t kOpSelB = ikxdl * NXdlPack + inxdl;
// OpSel for B: selects byte within packed int32_t
constexpr index_t kOpSelB = ikxdl * NXdlPack + inxdl;
// read C warp tensor from C block tensor
using c_iter_idx = std::conditional_t<TransposeC,
sequence<nIter, mIter>,
sequence<mIter, nIter>>;
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() =
c_block_tensor.get_y_sliced_thread_data(
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// read C warp tensor from C block tensor
using c_iter_idx = std::conditional_t<TransposeC,
sequence<nIter, mIter>,
sequence<mIter, nIter>>;
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM with MX scaling using pre-packed scale and OpSel
WarpGemm{}.template operator()<kOpSelA, kOpSelB>(c_warp_tensor,
a_warp_tensor,
b_warp_tensor,
a_scale_packed,
b_scale_packed);
// warp GEMM with MX scaling using pre-packed scale and OpSel
WarpGemm{}.template operator()<kOpSelA, kOpSelB>(c_warp_tensor,
a_warp_tensor,
b_warp_tensor,
a_scale_packed,
b_scale_packed);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});

View File

@@ -250,74 +250,74 @@ struct BlockGemmARegBRegCRegV2
// hot loop:
if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN)
{
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A Block window
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<kIter, mIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
constexpr auto kIter = number<km[number<0>{}]>{};
constexpr auto mIter = number<km[number<1>{}]>{};
// read A warp tensor from A Block window
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<kIter, mIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<kIter, nIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<kIter, nIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// warp GEMM
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
}
else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK)
{
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
// read A warp tensor from A Block window
AWarpTensor a_warp_tensor;
static_ford<sequence<MIterPerWarp, NIterPerWarp, KIterPerWarp>>{}([&](auto mnk) {
constexpr auto mIter = number<mnk[number<0>{}]>{};
constexpr auto nIter = number<mnk[number<1>{}]>{};
constexpr auto kIter = number<mnk[number<2>{}]>{};
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
// read A warp tensor from A Block window
AWarpTensor a_warp_tensor;
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
// warp GEMM
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
// warp GEMM
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
}
}

View File

@@ -109,13 +109,13 @@ struct BlockGemmARegBSmemCRegOneWarpV1
NIterPerWarp>
b_warp_windows;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
// Initialise one B warp window per (nIter, kIter) pair using a single flat
// compile-time loop over the NIterPerWarp x KIterPerWarp index space.
static_ford<sequence<NIterPerWarp, KIterPerWarp>>{}([&](auto nk) {
// Decompose the multi-index: component 0 is the N iteration, component 1 the K iteration.
constexpr auto nIter = number<nk[number<0>{}]>{};
constexpr auto kIter = number<nk[number<1>{}]>{};
// Start from b_warp_window_tmp, then shift the stored window by this iteration's offset.
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
});
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
});
#endif
@@ -141,35 +141,35 @@ struct BlockGemmARegBSmemCRegOneWarpV1
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
constexpr auto kIter = number<km[number<0>{}]>{};
constexpr auto mIter = number<km[number<1>{}]>{};
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
}

View File

@@ -116,13 +116,13 @@ struct BlockGemmARegBSmemCRegV1
NIterPerWarp>
b_warp_windows;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
// Build the table of B warp windows: the flat static_ford visits every
// (nIter, kIter) combination and hands it to the lambda as the multi-index `nk`.
static_ford<sequence<NIterPerWarp, KIterPerWarp>>{}([&](auto nk) {
// nk[0] -> nIter, nk[1] -> kIter, both rewrapped as compile-time number<> constants.
constexpr auto nIter = number<nk[number<0>{}]>{};
constexpr auto kIter = number<nk[number<1>{}]>{};
// Each entry is assigned from b_warp_window_tmp and then moved by
// {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}.
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
});
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
});
#endif
@@ -148,35 +148,35 @@ struct BlockGemmARegBSmemCRegV1
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
constexpr auto kIter = number<km[number<0>{}]>{};
constexpr auto mIter = number<km[number<1>{}]>{};
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
}

View File

@@ -103,13 +103,13 @@ struct BlockGemmARegBSmemCRegV2
NIterPerWarp>
b_warp_windows;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
// Populate b_warp_windows(nIter)(kIter) for all N/K iterations in one flattened
// compile-time loop (index space: NIterPerWarp x KIterPerWarp).
static_ford<sequence<NIterPerWarp, KIterPerWarp>>{}([&](auto nk) {
// Recover the per-dimension indices from the multi-index as number<> constants.
constexpr auto nIter = number<nk[number<0>{}]>{};
constexpr auto kIter = number<nk[number<1>{}]>{};
// Assign from the template window first; move_tile_window then offsets the stored entry in place.
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
});
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
});
#endif
@@ -135,36 +135,36 @@ struct BlockGemmARegBSmemCRegV2
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
constexpr auto kIter = number<kn[number<0>{}]>{};
constexpr auto nIter = number<kn[number<1>{}]>{};
// read B warp tensor from B Block window
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]);
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
}

View File

@@ -90,13 +90,13 @@ struct BlockGemmARegBSmemCRegV2R1
NIterPerWarp>
b_warp_windows;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
// One flat compile-time pass over (nIter, kIter) that fills in every B warp window.
static_ford<sequence<NIterPerWarp, KIterPerWarp>>{}([&](auto nk) {
// First component of nk indexes N, second indexes K.
constexpr auto nIter = number<nk[number<0>{}]>{};
constexpr auto kIter = number<nk[number<1>{}]>{};
// Window (nIter, kIter) = b_warp_window_tmp moved by
// {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}.
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
});
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
});
// check C-block-distribution
@@ -126,43 +126,43 @@ struct BlockGemmARegBSmemCRegV2R1
NIterPerWarp>
b_warp_tensors;
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_warp_tensors(nIter)(kIter) = load_tile(b_warp_windows(nIter)(kIter));
});
// Load every B warp window into b_warp_tensors with one flat compile-time loop
// over the KIterPerWarp x NIterPerWarp index space.
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
// The loop sequence lists K first, so kn[0] is kIter and kn[1] is nIter.
constexpr auto kIter = number<kn[number<0>{}]>{};
constexpr auto nIter = number<kn[number<1>{}]>{};
// Note the containers are indexed (nIter)(kIter), i.e. N-major, unlike the loop sequence.
b_warp_tensors(nIter)(kIter) = load_tile(b_warp_windows(nIter)(kIter));
});
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
const auto b_warp_tensor = b_warp_tensors(nIter)(kIter);
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
constexpr auto kIter = number<kn[number<0>{}]>{};
constexpr auto nIter = number<kn[number<1>{}]>{};
// read B warp tensor from B Block window
const auto b_warp_tensor = b_warp_tensors(nIter)(kIter);
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]);
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});

View File

@@ -116,13 +116,13 @@ struct BlockGemmASmemBRegCRegV1
MIterPerWarp>
a_warp_windows;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
// Initialise one A warp window per (mIter, kIter) pair using a single flat
// compile-time loop over the MIterPerWarp x KIterPerWarp index space.
static_ford<sequence<MIterPerWarp, KIterPerWarp>>{}([&](auto mk) {
// Decompose the multi-index: component 0 is the M iteration, component 1 the K iteration.
constexpr auto mIter = number<mk[number<0>{}]>{};
constexpr auto kIter = number<mk[number<1>{}]>{};
// Start from a_warp_window_tmp, then shift the stored window by this iteration's offset.
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
move_tile_window(a_warp_windows(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
move_tile_window(a_warp_windows(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
#endif
@@ -148,34 +148,34 @@ struct BlockGemmASmemBRegCRegV1
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A Block window
const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
constexpr auto kIter = number<km[number<0>{}]>{};
constexpr auto mIter = number<km[number<1>{}]>{};
// read A warp tensor from A Block window
const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
}

View File

@@ -85,13 +85,13 @@ struct BlockGemmASmemBSmemCRegV1
MIterPerWarp>
a_warp_windows;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
// Build the table of A warp windows: the flat static_ford visits every
// (mIter, kIter) combination and hands it to the lambda as the multi-index `mk`.
static_ford<sequence<MIterPerWarp, KIterPerWarp>>{}([&](auto mk) {
// mk[0] -> mIter, mk[1] -> kIter, both rewrapped as compile-time number<> constants.
constexpr auto mIter = number<mk[number<0>{}]>{};
constexpr auto kIter = number<mk[number<1>{}]>{};
// Each entry is assigned from a_warp_window_tmp and then moved by
// {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}.
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
move_tile_window(a_warp_windows(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
move_tile_window(a_warp_windows(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
#endif
@@ -120,13 +120,13 @@ struct BlockGemmASmemBSmemCRegV1
NIterPerWarp>
b_warp_windows;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
// B-side counterpart of the A window setup: fill b_warp_windows(nIter)(kIter)
// for all N/K iterations in one flattened compile-time loop.
static_ford<sequence<NIterPerWarp, KIterPerWarp>>{}([&](auto nk) {
// Recover the per-dimension indices from the multi-index as number<> constants.
constexpr auto nIter = number<nk[number<0>{}]>{};
constexpr auto kIter = number<nk[number<1>{}]>{};
// Assign from b_warp_window_tmp first; move_tile_window then offsets the stored entry in place.
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
});
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
});
#endif
@@ -138,31 +138,31 @@ struct BlockGemmASmemBSmemCRegV1
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block window
const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
constexpr auto kIter = number<km[number<0>{}]>{};
constexpr auto mIter = number<km[number<1>{}]>{};
// read A warp tensor from A block window
const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
}

View File

@@ -165,61 +165,60 @@ struct BlockGemmMxARegBSmemCRegV1
uniform_sequence_gen_t<BScaleWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
auto b_warp_window = b_warp_window_tmp;
move_tile_window(
b_warp_window,
{nIter * (NPerBlock / NIterPerWarp), kIter * (KPerBlock / KIterPerWarp)});
// read B warp tensor from B Block window
const auto b_warp_tensor = load_tile(b_warp_window);
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
constexpr auto kIter = number<kn[number<0>{}]>{};
constexpr auto nIter = number<kn[number<1>{}]>{};
auto b_warp_window = b_warp_window_tmp;
move_tile_window(
b_warp_window,
{nIter * (NPerBlock / NIterPerWarp), kIter * (KPerBlock / KIterPerWarp)});
// read B warp tensor from B Block window
const auto b_warp_tensor = load_tile(b_warp_window);
BScaleWarpTensor b_scale_warp_tensor;
BScaleWarpTensor b_scale_warp_tensor;
b_scale_warp_tensor.get_thread_buffer() =
b_scale_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<nIter / NIterPack, nIter % NIterPack, kIter>{},
b_scale_warp_y_index_zeros),
merge_sequences(sequence<1, 1, 1>{}, b_scale_warp_y_lengths));
b_scale_warp_tensor.get_thread_buffer() = b_scale_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<nIter / NIterPack, nIter % NIterPack, kIter>{},
b_scale_warp_y_index_zeros),
merge_sequences(sequence<1, 1, 1>{}, b_scale_warp_y_lengths));
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
AScaleWarpTensor a_scale_warp_tensor;
AScaleWarpTensor a_scale_warp_tensor;
a_scale_warp_tensor.get_thread_buffer() =
a_scale_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_scale_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_scale_warp_y_lengths));
a_scale_warp_tensor.get_thread_buffer() =
a_scale_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_scale_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_scale_warp_y_lengths));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter / NIterPack, nIter % NIterPack>{},
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1, 1>{}, c_warp_y_lengths));
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter / NIterPack, nIter % NIterPack>{},
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1, 1>{}, c_warp_y_lengths));
// warp GEMM
WarpGemm{}.template operator()<0, 0>(
c_warp_tensor,
a_warp_tensor,
b_warp_tensor,
int32_t(a_scale_warp_tensor.get_thread_buffer()[0]),
int32_t(b_scale_warp_tensor.get_thread_buffer()[0]));
// warp GEMM
WarpGemm{}.template operator()<0, 0>(
c_warp_tensor,
a_warp_tensor,
b_warp_tensor,
int32_t(a_scale_warp_tensor.get_thread_buffer()[0]),
int32_t(b_scale_warp_tensor.get_thread_buffer()[0]));
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter / NIterPack, nIter % NIterPack>{},
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter / NIterPack, nIter % NIterPack>{},
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
}

View File

@@ -239,39 +239,39 @@ struct BlockUniversalGemmAsBsCr
"C block tensor data type!");
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
constexpr auto kIter = number<km[number<0>{}]>{};
constexpr auto mIter = number<km[number<1>{}]>{};
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// warp GEMM
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
}
@@ -392,63 +392,59 @@ struct BlockUniversalGemmAsBsCr
0); // Prevents instruction reordering across this boundary
}
static_for<0, KInnerLoopIter, 1>{}([&](auto kInnerIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
static_ford<sequence<KInnerLoopIter, MIterPerWarp>>{}([&](auto km) {
constexpr auto kInnerIter = number<km[number<0>{}]>{};
constexpr auto mIter = number<km[number<1>{}]>{};
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kInnerIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kInnerIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() =
b_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kInnerIter>{},
b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kInnerIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() =
c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// The block_sync_lds() here performs double duty:
// A) safeguard against data hazard because the barrier from
// blockwise_gemm is moved here; B) reduce VMEM FIFO congestion
// by applying small delays to different wavefronts. It is
// performed near the end of the MAC cluster to minimize the
// lgkmcnt penalty.
if constexpr(kIter.value == KRepeat - 1 &&
kInnerIter.value == KInnerLoopIter - 1 &&
mIter.value == MIterPerWarp - 1 &&
nIter.value == NIterPerWarp - 1)
{
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
// warp GEMM
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// The block_sync_lds() here performs double duty:
// A) safeguard against data hazard because the barrier from
// blockwise_gemm is moved here; B) reduce VMEM FIFO congestion
// by applying small delays to different wavefronts. It is
// performed near the end of the MAC cluster to minimize the
// lgkmcnt penalty.
if constexpr(kIter.value == KRepeat - 1 &&
kInnerIter.value == KInnerLoopIter - 1 &&
mIter.value == MIterPerWarp - 1 &&
nIter.value == NIterPerWarp - 1)
{
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
// warp GEMM
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
if constexpr(kInnerIter.value == 0 && mIter.value == 0 &&
nIter.value == 0)
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
});
if constexpr(kInnerIter.value == 0 && mIter.value == 0 && nIter.value == 0)
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
});
});

View File

@@ -156,55 +156,54 @@ struct BlockWeightPreshuffleASmemBRegCReg
uniform_sequence_gen_t<BFlatDistribution::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B and C warp tensors from the B and C block tensors
BWarpTensor b_warp_tensor;
CWarpTensor c_warp_tensor;
static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
constexpr auto kIter = number<km[number<0>{}]>{};
constexpr auto mIter = number<km[number<1>{}]>{};
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B and C warp tensors from the B and C block tensors
BWarpTensor b_warp_tensor;
CWarpTensor c_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{},
typename sequence_split<decltype(b_block_y_index_zeros),
2>::right_type{}),
merge_sequences(
sequence<1, 1>{},
typename sequence_split<decltype(b_block_y_lengths), 2>::right_type{}));
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
merge_sequences(
sequence<nIter, kIter>{},
typename sequence_split<decltype(b_block_y_index_zeros), 2>::right_type{}),
merge_sequences(
sequence<1, 1>{},
typename sequence_split<decltype(b_block_y_lengths), 2>::right_type{}));
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WarpGemm{}(
c_warp_tensor, preloaded_a_warp_tensor(number<AwarpIter>{}), b_warp_tensor);
// warp GEMM
WarpGemm{}(
c_warp_tensor, preloaded_a_warp_tensor(number<AwarpIter>{}), b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
__builtin_amdgcn_sched_barrier(0x7F6);
});
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
load_tile(preloaded_a_warp_tensor(number<AwarpIter>{}),
a_load_windows[number<AkIter>{}][number<AmIter>{}]);
}
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
__builtin_amdgcn_sched_barrier(0x7F6);
});
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
load_tile(preloaded_a_warp_tensor(number<AwarpIter>{}),
a_load_windows[number<AkIter>{}][number<AmIter>{}]);
}
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
});
}
};

View File

@@ -88,28 +88,28 @@ struct BlockWeightPreshuffleASmemBSmemCRegV1
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block window
const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
constexpr auto kIter = number<km[number<0>{}]>{};
constexpr auto mIter = number<km[number<1>{}]>{};
// read A warp tensor from A block window
const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor(nIter)(kIter));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor(nIter)(kIter));
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
}

View File

@@ -210,45 +210,45 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg : public BlockGemmQuantBase
c_acc;
auto zero_accumulators = [&] {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, (WG::kM * WG::kN) / warp_size, 1>{}([&](auto i) {
c_acc(mIter)(nIter).get_thread_buffer()[i] = 0.0f;
}); // make sure WG::CWarpTensor exposes a clear/zero
static_ford<sequence<MIterPerWarp, NIterPerWarp, (WG::kM * WG::kN) / warp_size>>{}(
[&](auto mni) {
constexpr auto mIter = number<mni[number<0>{}]>{};
constexpr auto nIter = number<mni[number<1>{}]>{};
constexpr auto i = number<mni[number<2>{}]>{};
c_acc(mIter)(nIter).get_thread_buffer()[i] = 0.0f;
});
});
};
static_for<0, QScalesPerBlockRow, 1>{}([&](auto kQScale) {
zero_accumulators();
static_for<0, KIterPerQScale, 1>{}([&](auto kIterInQScale) {
constexpr auto kIter = kQScale * KIterPerQScale + kIterInQScale;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// warp GEMM
WG{}(c_acc(mIter)(nIter),
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor(nIter)(number<kIter>{}));
});
__builtin_amdgcn_sched_barrier(0x7F6);
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
load_and_convert_tile<UnaryOpSize>(
a_warp_tensor(number<AwarpIter>{}),
a_warp_windows(number<AmIter>{})(number<AkIter>{}));
}
// barrier
// Could be deleted
if constexpr((mIter == MIter_2nd_last))
{
block_sync_lds();
}
static_ford<sequence<KIterPerQScale, MIterPerWarp>>{}([&](auto km) {
constexpr auto kIterInQScale = number<km[number<0>{}]>{};
constexpr auto mIter = number<km[number<1>{}]>{};
constexpr auto kIter = kQScale * KIterPerQScale + kIterInQScale;
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// warp GEMM
WG{}(c_acc(mIter)(nIter),
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor(nIter)(number<kIter>{}));
});
__builtin_amdgcn_sched_barrier(0x7F6);
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
load_and_convert_tile<UnaryOpSize>(
a_warp_tensor(number<AwarpIter>{}),
a_warp_windows(number<AmIter>{})(number<AkIter>{}));
}
// barrier
// Could be deleted
if constexpr((mIter == MIter_2nd_last))
{
block_sync_lds();
}
});
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
AQPickerCommon<AQBlockTensor, Traits, mIter, kQScale> aq_picker(aq_block_tensor);

View File

@@ -127,105 +127,103 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
c_acc;
auto zero_accumulators = [&] {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, (WG::kM * WG::kN) / warp_size, 1>{}([&](auto i) {
c_acc(mIter)(nIter).get_thread_buffer()[i] = 0.0f;
}); // make sure WG::CWarpTensor exposes a clear/zero
static_ford<sequence<MIterPerWarp, NIterPerWarp, (WG::kM * WG::kN) / warp_size>>{}(
[&](auto mni) {
constexpr auto mIter = number<mni[number<0>{}]>{};
constexpr auto nIter = number<mni[number<1>{}]>{};
constexpr auto i = number<mni[number<2>{}]>{};
c_acc(mIter)(nIter).get_thread_buffer()[i] = 0.0f;
});
});
};
static_for<0, QScalesPerBlockRow, 1>{}([&](auto kQScale) {
zero_accumulators();
static_for<0, KIterPerQScale, 1>{}([&](auto kIterInQScale) {
constexpr auto kIter = kQScale * KIterPerQScale + kIterInQScale;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// warp GEMM
WG{}(c_acc(mIter)(nIter),
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor(nIter)(number<kIter>{}));
});
__builtin_amdgcn_sched_barrier(0x7F6);
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows(number<AmIter>{})(number<AkIter>{}));
}
// barrier
// Could be deleted
if constexpr((mIter == MIter_2nd_last))
{
block_sync_lds();
}
});
});
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_ford<sequence<KIterPerQScale, MIterPerWarp>>{}([&](auto km) {
constexpr auto kIterInQScale = number<km[number<0>{}]>{};
constexpr auto mIter = number<km[number<1>{}]>{};
constexpr auto kIter = kQScale * KIterPerQScale + kIterInQScale;
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
constexpr auto tbuf_offset =
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
merge_sequences(sequence<mIter, nIter>{},
c_warp_y_index_zeros)) /
CBlockTensor::PackedSize>{};
// warp GEMM
WG{}(c_acc(mIter)(nIter),
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor(nIter)(number<kIter>{}));
});
__builtin_amdgcn_sched_barrier(0x7F6);
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows(number<AmIter>{})(number<AkIter>{}));
}
// barrier
// Could be deleted
if constexpr((mIter == MIter_2nd_last))
{
block_sync_lds();
}
});
static_ford<sequence<MIterPerWarp, NIterPerWarp>>{}([&](auto mn) {
constexpr auto mIter = number<mn[number<0>{}]>{};
constexpr auto nIter = number<mn[number<1>{}]>{};
constexpr auto tbuf_offset =
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros)) /
CBlockTensor::PackedSize>{};
if constexpr(BPreshuffleQuant)
if constexpr(BPreshuffleQuant)
{
constexpr index_t reg_offset = nIter;
auto pull_from_lane = (__lane_id() & (WG::kN - 1)) * KPerBlockBQ + kQScale;
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
// cross lane ops
uint32_t scale_reg_dword;
if constexpr(std::is_same_v<BQDataType, float>)
{
constexpr index_t reg_offset = nIter;
auto pull_from_lane = (__lane_id() & (WG::kN - 1)) * KPerBlockBQ + kQScale;
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
// cross lane ops
uint32_t scale_reg_dword;
if constexpr(std::is_same_v<BQDataType, float>)
{
scale_reg_dword = ck_tile::bit_cast<uint32_t>(scale_reg);
}
else
{
scale_reg_dword = static_cast<uint32_t>(scale_reg);
}
// cross lane ops to get the value of scale_reg.
int gathered_scale_reg = __builtin_amdgcn_ds_bpermute(
pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword));
float scale_reg_f = cvt_scale_to_fp32(gathered_scale_reg);
static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row];
const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row];
c_ref = c_ref + acc_val * scale_reg_f;
});
scale_reg_dword = ck_tile::bit_cast<uint32_t>(scale_reg);
}
else
{
index_t reg_offset = [&]() {
if constexpr(BQuantGroupSize::kN >= (NWarp * WG::kN))
{
return (nIter * NWarp * WG::kN) / BQuantGroupSize::kN *
KPerBlockBQ +
kQScale;
}
else
{
return nIter * KPerBlockBQ + kQScale;
}
}();
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
float scale_reg_f = cvt_scale_to_fp32(scale_reg);
static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row];
const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row];
c_ref = c_ref + acc_val * scale_reg_f;
});
scale_reg_dword = static_cast<uint32_t>(scale_reg);
}
});
// cross lane ops to get the value of scale_reg.
int gathered_scale_reg = __builtin_amdgcn_ds_bpermute(
pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword));
float scale_reg_f = cvt_scale_to_fp32(gathered_scale_reg);
static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row];
const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row];
c_ref = c_ref + acc_val * scale_reg_f;
});
}
else
{
index_t reg_offset = [&]() {
if constexpr(BQuantGroupSize::kN >= (NWarp * WG::kN))
{
return (nIter * NWarp * WG::kN) / BQuantGroupSize::kN * KPerBlockBQ +
kQScale;
}
else
{
return nIter * KPerBlockBQ + kQScale;
}
}();
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
float scale_reg_f = cvt_scale_to_fp32(scale_reg);
static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row];
const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row];
c_ref = c_ref + acc_val * scale_reg_f;
});
}
});
});
}

View File

@@ -290,121 +290,115 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
constexpr auto warp_size = get_warp_size();
// hot loop:
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
CWarpTensor c_warp_tensor;
static_ford<sequence<MIterPerWarp, NIterPerWarp>>{}([&](auto mn) {
constexpr auto mIter = number<mn[number<0>{}]>{};
constexpr auto nIter = number<mn[number<1>{}]>{};
CWarpTensor c_warp_tensor;
static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() =
a_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() =
b_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
if constexpr(kIterInQScale == 0)
{
c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
}
else
{
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
}
});
constexpr auto tbuf_offset =
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
merge_sequences(sequence<mIter, nIter>{},
c_warp_y_index_zeros)) /
CBlockTensor::PackedSize>{};
// a_scale
AQPickerCommon<AQBlockTensor, Traits, mIter, kQScale> aq_picker(
aq_block_tensor);
if constexpr(BPreshuffleQuant)
if constexpr(kIterInQScale == 0)
{
constexpr index_t reg_offset = [&]() {
if constexpr(GemmTraits::BQuantGroupSize::kN >
(NWarp * WarpGemm::kN) &&
Traits::NPerBlock == GemmTraits::BQuantGroupSize::kN)
{
return kQScale;
}
else
{
return nIter;
}
}();
auto pull_from_lane =
(__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale;
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
// cross lane ops
uint32_t scale_reg_dword;
if constexpr(std::is_same_v<BQDataType, float>)
{
scale_reg_dword = ck_tile::bit_cast<uint32_t>(scale_reg);
}
else
{
scale_reg_dword = static_cast<uint32_t>(scale_reg);
}
// cross lane ops to get the value of scale_reg.
int gathered_scale_reg = __builtin_amdgcn_ds_bpermute(
pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword));
float b_scale_reg_f =
Base::cvt_scale_to_fp32<typename Traits::BQDataType>(
gathered_scale_reg);
static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
[&](auto c_row) {
float a_scale_reg_f = aq_picker.template pick<c_row>();
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] * a_scale_reg_f *
b_scale_reg_f);
});
c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
}
else
{
// Multiply bquant with accumulated C
constexpr index_t reg_offset = [&]() {
if constexpr(GemmTraits::BQuantGroupSize::kN >=
(NWarp * WarpGemm::kN))
return (nIter * NWarp * WarpGemm::kN) /
GemmTraits::BQuantGroupSize::kN *
Traits::KQPerBlock +
kQScale;
else
{
return nIter * Traits::KQPerBlock + kQScale;
}
}();
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
float b_scale_reg_f =
Base::cvt_scale_to_fp32<typename Traits::BQDataType>(scale_reg);
static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
[&](auto c_row) {
float a_scale_reg_f = aq_picker.template pick<c_row>();
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] * a_scale_reg_f *
b_scale_reg_f);
});
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
}
});
constexpr auto tbuf_offset =
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
merge_sequences(sequence<mIter, nIter>{},
c_warp_y_index_zeros)) /
CBlockTensor::PackedSize>{};
// a_scale
AQPickerCommon<AQBlockTensor, Traits, mIter, kQScale> aq_picker(
aq_block_tensor);
if constexpr(BPreshuffleQuant)
{
constexpr index_t reg_offset = [&]() {
if constexpr(GemmTraits::BQuantGroupSize::kN > (NWarp * WarpGemm::kN) &&
Traits::NPerBlock == GemmTraits::BQuantGroupSize::kN)
{
return kQScale;
}
else
{
return nIter;
}
}();
auto pull_from_lane =
(__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale;
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
// cross lane ops
uint32_t scale_reg_dword;
if constexpr(std::is_same_v<BQDataType, float>)
{
scale_reg_dword = ck_tile::bit_cast<uint32_t>(scale_reg);
}
else
{
scale_reg_dword = static_cast<uint32_t>(scale_reg);
}
// cross lane ops to get the value of scale_reg.
int gathered_scale_reg = __builtin_amdgcn_ds_bpermute(
pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword));
float b_scale_reg_f = Base::cvt_scale_to_fp32<typename Traits::BQDataType>(
gathered_scale_reg);
static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
[&](auto c_row) {
float a_scale_reg_f = aq_picker.template pick<c_row>();
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] * a_scale_reg_f *
b_scale_reg_f);
});
}
else
{
// Multiply bquant with accumulated C
constexpr index_t reg_offset = [&]() {
if constexpr(GemmTraits::BQuantGroupSize::kN >= (NWarp * WarpGemm::kN))
return (nIter * NWarp * WarpGemm::kN) /
GemmTraits::BQuantGroupSize::kN * Traits::KQPerBlock +
kQScale;
else
{
return nIter * Traits::KQPerBlock + kQScale;
}
}();
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
float b_scale_reg_f =
Base::cvt_scale_to_fp32<typename Traits::BQDataType>(scale_reg);
static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
[&](auto c_row) {
float a_scale_reg_f = aq_picker.template pick<c_row>();
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] * a_scale_reg_f *
b_scale_reg_f);
});
}
});
});
}

View File

@@ -268,54 +268,51 @@ struct AQuantBlockUniversalGemmAsBsCr
constexpr auto warp_size = get_warp_size();
// hot loop:
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
CWarpTensor c_warp_tensor;
static_ford<sequence<MIterPerWarp, NIterPerWarp>>{}([&](auto mn) {
constexpr auto mIter = number<mn[number<0>{}]>{};
constexpr auto nIter = number<mn[number<1>{}]>{};
CWarpTensor c_warp_tensor;
// for every column in AQ
static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
// for every warp corresponding to a quantization scale
static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
// for every column in AQ
static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
// for every warp corresponding to a quantization scale
static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() =
a_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() =
b_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
if constexpr(kIterInQScale == 0)
{
c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
}
else
{
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
}
});
if constexpr(kIterInQScale == 0)
{
c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
}
else
{
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
}
});
constexpr auto tbuf_offset =
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
merge_sequences(sequence<mIter, nIter>{},
c_warp_y_index_zeros)) /
CBlockTensor::PackedSize>{};
constexpr auto tbuf_offset =
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
merge_sequences(sequence<mIter, nIter>{},
c_warp_y_index_zeros)) /
CBlockTensor::PackedSize>{};
AQPickerCommon<AQBlockTensor, Traits, mIter, kQScale> aq_picker(
aq_block_tensor);
AQPickerCommon<AQBlockTensor, Traits, mIter, kQScale> aq_picker(
aq_block_tensor);
static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
[&](auto c_row) {
float scale_reg_f = aq_picker.template pick<c_row>();
static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}([&](auto c_row) {
float scale_reg_f = aq_picker.template pick<c_row>();
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
});
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
});
});
});

View File

@@ -290,57 +290,55 @@ struct BQuantBlockUniversalGemmAsBsCr
using SrcVectorRawType = ext_vector_t<BDataTypeRaw, UnaryOpSize_ / BPackedSize>;
using DstVectorType = ext_vector_t<ComputeDataType, UnaryOpSize_>;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
// B scale register offset
constexpr index_t reg_offset = [&]() {
if constexpr(GemmTraits::BQuantGroupSize::kN >= (NWarp * WarpGemm::kN))
return ((nIter * NWarp * WarpGemm::kN) /
GemmTraits::BQuantGroupSize::kN) *
Traits::KQPerBlock +
kQScale;
else
{
return nIter * Traits::KQPerBlock + kQScale;
}
}();
static_ford<sequence<NIterPerWarp, Traits::QScalesPerBlockRow>>{}([&](auto nk) {
constexpr auto nIter = number<nk[number<0>{}]>{};
constexpr auto kQScale = number<nk[number<1>{}]>{};
// B scale register offset
constexpr index_t reg_offset = [&]() {
if constexpr(GemmTraits::BQuantGroupSize::kN >= (NWarp * WarpGemm::kN))
return ((nIter * NWarp * WarpGemm::kN) / GemmTraits::BQuantGroupSize::kN) *
Traits::KQPerBlock +
kQScale;
else
{
return nIter * Traits::KQPerBlock + kQScale;
}
}();
// Get B scale from thread buffer
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
float b_scale_f = float(scale_reg);
// Get B scale from thread buffer
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
float b_scale_f = float(scale_reg);
static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
// Thread buffers
using BWarpThreadBuffer = decltype(b_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)));
using BLDSThreadBuffer = decltype(b_warp_tile_lds_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)));
static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
// Thread buffers
using BWarpThreadBuffer = decltype(b_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)));
using BLDSThreadBuffer = decltype(b_warp_tile_lds_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)));
BWarpThreadBuffer b_warp_thread_buffer;
BLDSThreadBuffer b_lds_thread_buffer;
BWarpThreadBuffer b_warp_thread_buffer;
BLDSThreadBuffer b_lds_thread_buffer;
// Load thread buffer from tile (LDS type)
b_lds_thread_buffer = b_warp_tile_lds_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// Load thread buffer from tile (LDS type)
b_lds_thread_buffer = b_warp_tile_lds_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// Apply scale to B thread buffer and cast
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
elementwise_op(
b_warp_thread_buffer.template get_as<DstVectorType>()(i),
b_lds_thread_buffer.template get_as<SrcVectorRawType>()[i],
b_scale_f);
});
// Store B thread buffer to tile (MMA type)
b_warp_tile_.set_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths),
b_warp_thread_buffer);
// Apply scale to B thread buffer and cast
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
elementwise_op(b_warp_thread_buffer.template get_as<DstVectorType>()(i),
b_lds_thread_buffer.template get_as<SrcVectorRawType>()[i],
b_scale_f);
});
// Store B thread buffer to tile (MMA type)
b_warp_tile_.set_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths),
b_warp_thread_buffer);
});
});
}
@@ -361,113 +359,107 @@ struct BQuantBlockUniversalGemmAsBsCr
constexpr auto warp_size = get_warp_size();
// hot loop:
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
CWarpTensor c_warp_tensor;
static_ford<sequence<MIterPerWarp, NIterPerWarp>>{}([&](auto mn) {
constexpr auto mIter = number<mn[number<0>{}]>{};
constexpr auto nIter = number<mn[number<1>{}]>{};
CWarpTensor c_warp_tensor;
static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() =
a_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() =
b_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
if constexpr(kIterInQScale == 0)
{
c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
}
else
{
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
}
});
constexpr auto tbuf_offset =
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
merge_sequences(sequence<mIter, nIter>{},
c_warp_y_index_zeros)) /
CBlockTensor::PackedSize>{};
if constexpr(BPreshuffleQuant)
if constexpr(kIterInQScale == 0)
{
constexpr index_t reg_offset = [&]() {
if constexpr(GemmTraits::BQuantGroupSize::kN >
(NWarp * WarpGemm::kN) &&
Traits::NPerBlock == GemmTraits::BQuantGroupSize::kN)
{
return kQScale; // prefill: one quant group per block
}
else
{
return nIter; // decode or multiple groups per warp
}
}();
auto pull_from_lane =
(__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale;
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
// cross lane ops
uint32_t scale_reg_dword;
if constexpr(std::is_same_v<BQDataType, float>)
{
scale_reg_dword = ck_tile::bit_cast<uint32_t>(scale_reg);
}
else
{
scale_reg_dword = static_cast<uint32_t>(scale_reg);
}
// cross lane ops to get the value of scale_reg.
int gathered_scale_reg = __builtin_amdgcn_ds_bpermute(
pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword));
float scale_reg_f =
Base::cvt_scale_to_fp32<typename Traits::BQDataType>(
gathered_scale_reg);
static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
[&](auto c_row) {
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
});
c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
}
else
{
// Multiply bquant with accumulated C
constexpr index_t reg_offset = [&]() {
if constexpr(GemmTraits::BQuantGroupSize::kN >=
(NWarp * WarpGemm::kN))
return (nIter * NWarp * WarpGemm::kN) /
GemmTraits::BQuantGroupSize::kN *
Traits::KQPerBlock +
kQScale;
else
{
return nIter * Traits::KQPerBlock + kQScale;
}
}();
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
float scale_reg_f =
Base::cvt_scale_to_fp32<typename Traits::BQDataType>(scale_reg);
static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
[&](auto c_row) {
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
});
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
}
});
constexpr auto tbuf_offset =
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
merge_sequences(sequence<mIter, nIter>{},
c_warp_y_index_zeros)) /
CBlockTensor::PackedSize>{};
if constexpr(BPreshuffleQuant)
{
constexpr index_t reg_offset = [&]() {
if constexpr(GemmTraits::BQuantGroupSize::kN > (NWarp * WarpGemm::kN) &&
Traits::NPerBlock == GemmTraits::BQuantGroupSize::kN)
{
return kQScale; // prefill: one quant group per block
}
else
{
return nIter; // decode or multiple groups per warp
}
}();
auto pull_from_lane =
(__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale;
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
// cross lane ops
uint32_t scale_reg_dword;
if constexpr(std::is_same_v<BQDataType, float>)
{
scale_reg_dword = ck_tile::bit_cast<uint32_t>(scale_reg);
}
else
{
scale_reg_dword = static_cast<uint32_t>(scale_reg);
}
// cross lane ops to get the value of scale_reg.
int gathered_scale_reg = __builtin_amdgcn_ds_bpermute(
pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword));
float scale_reg_f = Base::cvt_scale_to_fp32<typename Traits::BQDataType>(
gathered_scale_reg);
static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
[&](auto c_row) {
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
});
}
else
{
// Multiply bquant with accumulated C
constexpr index_t reg_offset = [&]() {
if constexpr(GemmTraits::BQuantGroupSize::kN >= (NWarp * WarpGemm::kN))
return (nIter * NWarp * WarpGemm::kN) /
GemmTraits::BQuantGroupSize::kN * Traits::KQPerBlock +
kQScale;
else
{
return nIter * Traits::KQPerBlock + kQScale;
}
}();
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
float scale_reg_f =
Base::cvt_scale_to_fp32<typename Traits::BQDataType>(scale_reg);
static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
[&](auto c_row) {
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
});
}
});
});
}

View File

@@ -288,22 +288,22 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
MIterPerWarp>
a_warp_windows_pong;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
static_ford<sequence<MIterPerWarp, KIterPerWarp>>{}([&](auto mk) {
constexpr auto mIter = number<mk[number<0>{}]>{};
constexpr auto kIter = number<mk[number<1>{}]>{};
a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
move_tile_window(a_warp_windows_ping(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
move_tile_window(a_warp_windows_ping(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
static_ford<sequence<MIterPerWarp, KIterPerWarp>>{}([&](auto mk) {
constexpr auto mIter = number<mk[number<0>{}]>{};
constexpr auto kIter = number<mk[number<1>{}]>{};
a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
move_tile_window(a_warp_windows_pong(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
move_tile_window(a_warp_windows_pong(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
// Block GEMM
@@ -366,16 +366,16 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// prefetch B
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
static_ford<sequence<NIterPerWarp, KIterPerWarp>>{}([&](auto nk) {
constexpr auto nIter = number<nk[number<0>{}]>{};
constexpr auto kIter = number<nk[number<1>{}]>{};
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_ping(nIter)(kIter),
b_flat_dram_windows(nIter)(kIter));
});
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_ping(nIter)(kIter),
b_flat_dram_windows(nIter)(kIter));
});
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
@@ -448,15 +448,15 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
bq_block_tile,
a_warp_windows_ping);
// prefetch B(2i+1)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
constexpr auto kIter = number<kn[number<0>{}]>{};
constexpr auto nIter = number<kn[number<1>{}]>{};
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_pong(nIter)(kIter),
b_flat_dram_windows(nIter)(kIter));
});
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_pong(nIter)(kIter),
b_flat_dram_windows(nIter)(kIter));
});
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
aq_block_tile_2 = load_tile(aq_copy_dram_window);
@@ -473,15 +473,15 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
// Next K
// prefetch B(2i+2)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
constexpr auto kIter = number<kn[number<0>{}]>{};
constexpr auto nIter = number<kn[number<1>{}]>{};
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_ping(nIter)(kIter),
b_flat_dram_windows(nIter)(kIter));
});
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_ping(nIter)(kIter),
b_flat_dram_windows(nIter)(kIter));
});
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
aq_block_tile = load_tile(aq_copy_dram_window);
@@ -520,16 +520,16 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
if constexpr(TailNum == TailNumber::Even)
{
// prefetch B(loopK)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
constexpr auto kIter = number<kn[number<0>{}]>{};
constexpr auto nIter = number<kn[number<1>{}]>{};
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_pong(nIter)(kIter),
b_flat_dram_windows(nIter)(kIter));
});
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_pong(nIter)(kIter),
b_flat_dram_windows(nIter)(kIter));
});
aq_block_tile_2 = load_tile(aq_copy_dram_window);
bq_block_tile_2 = load_tile(bq_copy_dram_window);

View File

@@ -275,22 +275,22 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
MIterPerWarp>
a_warp_windows_pong;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
static_ford<sequence<MIterPerWarp, KIterPerWarp>>{}([&](auto mk) {
constexpr auto mIter = number<mk[number<0>{}]>{};
constexpr auto kIter = number<mk[number<1>{}]>{};
a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
move_tile_window(a_warp_windows_ping(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
move_tile_window(a_warp_windows_ping(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
static_ford<sequence<MIterPerWarp, KIterPerWarp>>{}([&](auto mk) {
constexpr auto mIter = number<mk[number<0>{}]>{};
constexpr auto kIter = number<mk[number<1>{}]>{};
a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
move_tile_window(a_warp_windows_pong(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
move_tile_window(a_warp_windows_pong(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
// Block GEMM
@@ -337,16 +337,16 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// prefetch B
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
static_ford<sequence<NIterPerWarp, KIterPerWarp>>{}([&](auto nk) {
constexpr auto nIter = number<nk[number<0>{}]>{};
constexpr auto kIter = number<nk[number<1>{}]>{};
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_ping(nIter)(kIter),
b_flat_dram_windows(nIter)(kIter));
});
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_ping(nIter)(kIter),
b_flat_dram_windows(nIter)(kIter));
});
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
@@ -424,15 +424,15 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
bq_block_tile,
a_warp_windows_ping);
// prefetch B(2i+1)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
constexpr auto kIter = number<kn[number<0>{}]>{};
constexpr auto nIter = number<kn[number<1>{}]>{};
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_pong(nIter)(kIter),
b_flat_dram_windows(nIter)(kIter));
});
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_pong(nIter)(kIter),
b_flat_dram_windows(nIter)(kIter));
});
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
@@ -461,15 +461,15 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
// Next K
// prefetch B(2i+2)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
constexpr auto kIter = number<kn[number<0>{}]>{};
constexpr auto nIter = number<kn[number<1>{}]>{};
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_ping(nIter)(kIter),
b_flat_dram_windows(nIter)(kIter));
});
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_ping(nIter)(kIter),
b_flat_dram_windows(nIter)(kIter));
});
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
@@ -518,16 +518,16 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
if constexpr(TailNum == TailNumber::Even)
{
// prefetch B(loopK)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
constexpr auto kIter = number<kn[number<0>{}]>{};
constexpr auto nIter = number<kn[number<1>{}]>{};
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_pong(nIter)(kIter),
b_flat_dram_windows(nIter)(kIter));
});
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_pong(nIter)(kIter),
b_flat_dram_windows(nIter)(kIter));
});
bq_block_tile_2 = load_tile(bq_copy_dram_window);

View File

@@ -303,11 +303,11 @@ struct BlockNormReduceCrossWarpSync
index_t local_warp_id = warp_id / num_reduce_warps;
index_t local_smem_os = local_warp_id * num_reduce_warps;
smem_dtype all_scratch[thread_buf_size * num_reduce_warps];
static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
static_for<0, num_reduce_warps, 1>{}([&](auto i_1) {
all_scratch[i_0 * num_reduce_warps + i_1] =
smem_ptr[i_0 * num_warps + local_smem_os + i_1];
});
static_ford<sequence<thread_buf_size, num_reduce_warps>>{}([&](auto ii) {
constexpr auto i_0 = number<ii[number<0>{}]>{};
constexpr auto i_1 = number<ii[number<1>{}]>{};
all_scratch[i_0 * num_reduce_warps + i_1] =
smem_ptr[i_0 * num_warps + local_smem_os + i_1];
});
block_sync_lds(); // TODO: we don't need sync here

View File

@@ -631,17 +631,17 @@ struct BlockReduce2dLinearCrossWarpSync
IndexDataType> all_indices;
// Load data from shared memory
static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
static_for<0, num_reduce_warps, 1>{}([&](auto i_1) {
all_scratch[i_0 * num_reduce_warps + i_1] =
smem_ptr[i_0 * num_warps + local_smem_os + i_1];
static_ford<sequence<thread_buf_size, num_reduce_warps>>{}([&](auto ii) {
constexpr auto i_0 = number<ii[number<0>{}]>{};
constexpr auto i_1 = number<ii[number<1>{}]>{};
all_scratch[i_0 * num_reduce_warps + i_1] =
smem_ptr[i_0 * num_warps + local_smem_os + i_1];
if constexpr(kProcessIndex)
{
all_indices[i_0 * num_reduce_warps + i_1] =
smem_indices[i_0 * num_warps + local_smem_os + i_1];
}
});
if constexpr(kProcessIndex)
{
all_indices[i_0 * num_reduce_warps + i_1] =
smem_indices[i_0 * num_warps + local_smem_os + i_1];
}
});
block_sync_lds(); // TODO: we don't need sync here