[rocm-libraries] ROCm/rocm-libraries#5031 (commit 1d86a92)

[CK] Replace nested static_for with static_ford to reduce
 device IR function emissions [1B] (#5031)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Summary

### Rationale
CK's GPU kernels are among the slowest files in the ROCm build, with a
single translation
unit taking more than 10 minutes. Profiling with `-ftime-trace` identified
nested `static_for`
loops as the root cause: each nesting level multiplies the number of
unique lambda IR
functions the compiler must process. A 2-level nest of `static_for<0, M,
1>` /
`static_for<0, N, 1>` produces M×N unique lambda types. With typical
GEMM dimensions
(M=16, N=4), a single nest generates 64 unique functions — and these
nests appear hundreds
of times across the codebase.

The LLVM backend's CGSCC (Call Graph Strongly Connected Components)
framework processes
each function independently, so reducing function count directly reduces
backend time.

### What changed
393 nested compile-time loop patterns across 73 files are converted to
`static_ford`, which
flattens multi-dimensional compile-time iteration into a single
`static_for` with index
decomposition. This eliminates 994 `static_for` nesting levels (42%
reduction).

Three pattern categories were converted:
- **Category A**: `static_for` wrapping `static_ford` — fold outer
dimension into ford
- **Category B**: nested `static_ford` — merge into single
higher-dimensional ford
- **Category C**: nested `static_for` chains — convert to single
`static_ford`

### Verification

**ASM equivalence: PASS — 51/51 device assembly files identical (gfx942
+ gfx1100)**

| Architecture | Files compared | Largest file | Result |
|---|---|---|---|
| gfx942 | 36 | 386,685 lines | ALL MATCH |
| gfx1100 | 15 | 47,769 lines | ALL MATCH |

**Build time (Wilcoxon signed-rank test, 7 paired trials):**

| Target | Pre (s) | Post (s) | Delta | p-value |
|---|---|---|---|---|
| bscale | 169 | 152 | **-9.8%** | 0.016 \* |
| xdl_v1234 | 207 | 194 | **-6.6%** | 0.016 \* |
| preshuffle | 275 | 264 | **-3.9%** | 0.016 \* |
| xdl_base | 142 | 137 | **-3.2%** | 0.031 \* |

**IR function counts (device backend, gfx942):**

| Target | InstFunc Δ | CodeGen Δ | Compiler Δ |
|---|---|---|---|
| bscale | -13,043 (-8.2%) | -2,103 (-3.5%) | -10.7% |
| xdl_v1234 | -9,431 (-5.7%) | +59 (+0.1%) | -5.2% |
| xdl_base | -6,162 (-4.9%) | -1,141 (-2.5%) | -2.2% |
| xdl_old | -3,234 (-3.7%) | -963 (-8.7%) | -3.3% |

### Value
- **994 fewer `static_for` nesting levels** (-42%) across 73 files
- **393 new `static_ford` sites** created (adding to the 4 that pre-existed)
- **Up to 9.8% compile-time reduction** on representative targets
(statistically significant, p < 0.05)
- **Up to 13K fewer IR function instantiations** per translation unit
- Net -849 LOC from reduced indentation
- **Zero ASM changes** — identical device code output verified on gfx942
and gfx1100
- All scheduling barriers, `if constexpr` guards, and MFMA/WMMA
accumulation order preserved

### Files changed (73)
- `block/`: 47 files (GEMM pipelines — xdlops, wmma, moe, preshuffle,
blockscale variants)
- `grid/`: 20 files (softmax, normalization, reduction, attention,
layernorm)
- `thread/`: 5 files (tensor slice transfer, contraction, GEMM dlops,
reduction)
- `tensor_description/`: 1 file (tensor_adaptor)

## Test plan
- [x] `static_ford` tested with 21 unit tests in
`test/util/unit_ford.cpp`
  (1D-4D, custom orders, compile-time verification)
- [x] All conversions preserve iteration order, `block_sync_lds()`
placement,
  `if constexpr` scheduling guards, and MFMA/WMMA accumulation order
- [x] ASM equivalence verified: 51 device `.s` files across gfx942 +
gfx1100
- [x] Build-time improvement statistically confirmed (Wilcoxon, p <
0.05, 4 targets)
- [x] IR function count reduction confirmed via `-ftime-trace` on 7
targets
- [x] Detection script reports 0 remaining safely convertible patterns
  (180 patterns intentionally left unconverted, each with a structural reason)
- [x] Existing CI tests (GEMM, softmax, normalization, batch norm,
reduction,
  attention) exercise all converted code paths

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Christopher Millette
2026-03-18 14:46:50 +00:00
committed by assistant-librarian[bot]
parent 5f90f69795
commit e5683e2290
73 changed files with 9398 additions and 10247 deletions

View File

@@ -202,22 +202,22 @@ struct BlockwiseGemmWmmaops_pipeline_base
using AScaleThreadDesc = decltype(AScaleStruct::scale_thread_desc);
using BScaleThreadDesc = decltype(BScaleStruct::scale_thread_desc);
static_for<0, num_scale_m_block, 1>{}([&](auto m0) {
static_for<0, num_scale_n_block, 1>{}([&](auto n0) {
static_for<0, num_scale_k_block, 1>{}([&](auto k0) {
constexpr index_t c_offset =
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
constexpr index_t a_offset =
AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
constexpr index_t b_offset =
BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
static_ford<Sequence<num_scale_m_block, num_scale_n_block, num_scale_k_block>>{}(
[&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto k0 = Number<mnk[Number<2>{}]>{};
constexpr index_t c_offset =
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
constexpr index_t a_offset =
AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
constexpr index_t b_offset =
BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
c_scale_thread_bufs(I0)(Number<c_offset>{}) =
a_scale_struct.scale_thread_bufs(I0)[Number<a_offset>{}] *
b_scale_struct.scale_thread_bufs(I0)[Number<b_offset>{}];
});
c_scale_thread_bufs(I0)(Number<c_offset>{}) =
a_scale_struct.scale_thread_bufs(I0)[Number<a_offset>{}] *
b_scale_struct.scale_thread_bufs(I0)[Number<b_offset>{}];
});
});
}
__device__ void Clear()

View File

@@ -224,87 +224,75 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
auto blockwise_gemm_func = [&]() {
// Local load
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(I0, m0, k0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0, I0, I0, I0),
a_thread_buf);
if constexpr(m0 == I0)
static_ford<Sequence<KRepeat, MRepeat>>{}([&](auto km) {
constexpr auto k0 = Number<km[Number<0>{}]>{};
constexpr auto m0 = Number<km[Number<1>{}]>{};
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(I0, m0, k0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0, I0, I0, I0),
a_thread_buf);
if constexpr(m0 == I0)
{
if constexpr(ck::is_same<BScaleStruct, Empty>::value == true)
{
if constexpr(ck::is_same<BScaleStruct, Empty>::value == true)
{
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(I0, n0, k0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0, I0),
b_thread_buf);
});
}
else
{
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(I0, n0, k0, I0, I0, I0, I0),
b_block_buf,
b_scale_struct.scale_thread_bufs(
I0)[Number<n0 * BScaleStruct::num_scale_k_block +
k0 / BScaleStruct::num_scale_krepeat>{}],
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0, I0),
b_thread_buf);
});
}
}
static_for<0, KInner, 1>{}([&](auto k_inner) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
I0,
I0,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
n0,
I0,
I0,
I0,
I0,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(I0, n0, k0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0, I0),
b_thread_buf);
});
}
else
{
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(I0, n0, k0, I0, I0, I0, I0),
b_block_buf,
b_scale_struct.scale_thread_bufs(
I0)[Number<n0 * BScaleStruct::num_scale_k_block +
k0 / BScaleStruct::num_scale_krepeat>{}],
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0, I0),
b_thread_buf);
});
}
}
static_ford<Sequence<KInner, NRepeat>>{}([&](auto kn) {
constexpr auto k_inner = Number<kn[Number<0>{}]>{};
constexpr auto n0 = Number<kn[Number<1>{}]>{};
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
Number<kk / A_K1>{}, I0, I0, I0, I0, I0, Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
Number<kk / B_K1>{}, n0, I0, I0, I0, I0, Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
};
@@ -341,20 +329,17 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
static_for<0, num_buffer_load_inst, 1>{}([&](auto) {
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
});
static_for<0, KRepeat, 1>{}([&](auto) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
if constexpr(m0 == I0)
{
static_for<0, NRepeat, 1>{}([&](auto) {
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
}
static_for<0, KInner, 1>{}([&](auto) {
static_for<0, NRepeat, 1>{}([&](auto) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
});
static_ford<Sequence<KRepeat, MRepeat>>{}([&](auto km) {
constexpr auto m0 = Number<km[Number<1>{}]>{};
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
if constexpr(m0 == I0)
{
static_for<0, NRepeat, 1>{}([&](auto) {
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
}
static_ford<Sequence<KInner, NRepeat>>{}([&](auto) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
});
});
static_for<0, num_ds_write_inst, 1>{}([&](auto) {
@@ -464,59 +449,55 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
});
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) {
c_scale_struct.Clear();
static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_ford<Sequence<MRepeat, NRepeat, NumScaleKBlock>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto kscale0 = Number<mnk[Number<2>{}]>{};
c_scale_struct.Clear();
static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KInner, 1>{}([&](auto k_inner) {
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index =
kscale0 * (KRepeat / NumScaleKBlock) + k0;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<Base::a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k_index,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index =
kscale0 * (KRepeat / NumScaleKBlock) + k0;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<Base::b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
n0,
k_index,
I0,
I0,
I0,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
wmma_gemm.Run(
a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(
Number<0>{}));
});
static_for<0, KInner, 1>{}([&](auto k_inner) {
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index = kscale0 * (KRepeat / NumScaleKBlock) + k0;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<Base::a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k_index,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index = kscale0 * (KRepeat / NumScaleKBlock) + k0;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<Base::b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
n0,
k_index,
I0,
I0,
I0,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(
Number<0>{}));
});
});
c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
});
};
@@ -850,73 +831,71 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
}
static_for<0, KRepeatPerCluster, 1>{}([&](auto k0_inner) {
static_for<0, KInner, 1>{}([&](auto k_inner) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_ford<Sequence<KRepeatPerCluster, KInner, MRepeat, NRepeat>>{}(
[&](auto kkmn) {
constexpr auto k0_inner = Number<kkmn[Number<0>{}]>{};
constexpr auto k_inner = Number<kkmn[Number<1>{}]>{};
constexpr auto m0 = Number<kkmn[Number<2>{}]>{};
constexpr auto n0 = Number<kkmn[Number<3>{}]>{};
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k0_inner,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
n0,
k0_inner,
I0,
I0,
I0,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
// The block_sync_lds() here performs double duty:
// A) safeguard against data hazard.
// B) reduce VMEM FIFO congestion by applying small delays to
// different wavefronts.
// It is performed near the end of MAC cluster to minimize lgkmcnt
// penalty
if constexpr(k0_offset + k0_inner == KRepeat - 1 &&
m0 == MRepeat - 1 && n0 == NRepeat - 1)
{
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
wmma_gemm.Run(
a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0)
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
});
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k0_inner,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
n0,
k0_inner,
I0,
I0,
I0,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
// The block_sync_lds() here performs double duty:
// A) safeguard against data hazard.
// B) reduce VMEM FIFO congestion by applying small delays to
// different wavefronts.
// It is performed near the end of MAC cluster to minimize lgkmcnt
// penalty
if constexpr(k0_offset + k0_inner == KRepeat - 1 && m0 == MRepeat - 1 &&
n0 == NRepeat - 1)
{
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0)
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
});
});
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(0);
@@ -1249,15 +1228,15 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
// Local prefetch A1
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(I0, m0, k0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, k0, I0, I0, I0, I0),
a_thread_buf);
});
static_ford<Sequence<MRepeat, KRepeat>>{}([&](auto mk) {
constexpr auto m0 = Number<mk[Number<0>{}]>{};
constexpr auto k0 = Number<mk[Number<1>{}]>{};
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(I0, m0, k0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, k0, I0, I0, I0, I0),
a_thread_buf);
});
// Initialize C
@@ -1287,66 +1266,64 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KInner, 1>{}([&](auto k_inner) {
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k0,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[wmma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
I0,
I0,
n0,
I0,
k0,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
wmma_gemm.Run(
a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_ford<Sequence<MRepeat, NRepeat, KRepeat>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto k0 = Number<mnk[Number<2>{}]>{};
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KInner, 1>{}([&](auto k_inner) {
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k0,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[wmma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
I0,
I0,
n0,
I0,
k0,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
block_sync_lds();
// loop prefetch copy
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(I0, m0, k0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, k0, I0, I0, I0, I0),
a_thread_buf);
});
static_ford<Sequence<MRepeat, KRepeat>>{}([&](auto mk) {
constexpr auto m0 = Number<mk[Number<0>{}]>{};
constexpr auto k0 = Number<mk[Number<1>{}]>{};
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(I0, m0, k0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, k0, I0, I0, I0, I0),
a_thread_buf);
});
HotLoopScheduler();
@@ -1373,112 +1350,86 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KInner, 1>{}([&](auto k_inner) {
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k0,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
I0,
I0,
n0,
I0,
k0,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_ford<Sequence<MRepeat, NRepeat, KRepeat>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto k0 = Number<mnk[Number<2>{}]>{};
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KInner, 1>{}([&](auto k_inner) {
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
Number<kk / A_K1>{}, m0, k0, I0, I0, I0, Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(make_tuple(
Number<kk / B_K1>{}, I0, I0, n0, I0, k0, Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
block_sync_lds();
// tail Local Prefetch A1
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(I0, m0, k0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, k0, I0, I0, I0, I0),
a_thread_buf);
});
static_ford<Sequence<MRepeat, KRepeat>>{}([&](auto mk) {
constexpr auto m0 = Number<mk[Number<0>{}]>{};
constexpr auto k0 = Number<mk[Number<1>{}]>{};
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(I0, m0, k0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, k0, I0, I0, I0, I0),
a_thread_buf);
});
__builtin_amdgcn_sched_barrier(0);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KInner, 1>{}([&](auto k_inner) {
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k0,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
I0,
I0,
n0,
I0,
k0,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_ford<Sequence<MRepeat, NRepeat, KRepeat>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto k0 = Number<mnk[Number<2>{}]>{};
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KInner, 1>{}([&](auto k_inner) {
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
Number<kk / A_K1>{}, m0, k0, I0, I0, I0, Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(make_tuple(
Number<kk / B_K1>{}, I0, I0, n0, I0, k0, Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
// Let's leak last WMMA block to epilogue region, cover the potential lds-shuffle
@@ -1487,49 +1438,36 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
}
else if constexpr(TailNum == TailNumber::Odd)
{
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KInner, 1>{}([&](auto k_inner) {
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k0,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
I0,
I0,
n0,
I0,
k0,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_ford<Sequence<MRepeat, NRepeat, KRepeat>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto k0 = Number<mnk[Number<2>{}]>{};
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KInner, 1>{}([&](auto k_inner) {
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
Number<kk / A_K1>{}, m0, k0, I0, I0, I0, Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(make_tuple(
Number<kk / B_K1>{}, I0, I0, n0, I0, k0, Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
}
@@ -1590,70 +1528,65 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
auto c_scale_struct = CScaleStruct{};
auto gemm_core_func = [&](auto reg_buf) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) {
c_scale_struct.Clear();
static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KInner, 1>{}([&](auto k_inner) {
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index =
kscale0 * (KRepeat / NumScaleKBlock) + k0;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k_index,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index =
kscale0 * (KRepeat / NumScaleKBlock) + k0;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
I0,
I0,
n0,
I0,
k_index,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
wmma_gemm.Run(
a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(
Number<0>{}));
});
static_ford<Sequence<MRepeat, NRepeat, NumScaleKBlock>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto kscale0 = Number<mnk[Number<2>{}]>{};
c_scale_struct.Clear();
static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KInner, 1>{}([&](auto k_inner) {
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index = kscale0 * (KRepeat / NumScaleKBlock) + k0;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k_index,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index = kscale0 * (KRepeat / NumScaleKBlock) + k0;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[reg_buf][Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
I0,
I0,
n0,
I0,
k_index,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(
Number<0>{}));
});
});
c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
});
};
auto a_local_prefetch_func = [&]() {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(I0, m0, k0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, k0, I0, I0, I0, I0),
a_thread_buf);
});
static_ford<Sequence<MRepeat, KRepeat>>{}([&](auto mk) {
constexpr auto m0 = Number<mk[Number<0>{}]>{};
constexpr auto k0 = Number<mk[Number<1>{}]>{};
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(I0, m0, k0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, k0, I0, I0, I0, I0),
a_thread_buf);
});
};

View File

@@ -434,53 +434,38 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KInner, 1>{}([&](auto k_inner) {
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_ford<Sequence<KRepeat, MRepeat, NRepeat, KInner>>{}([&](auto kmnk) {
constexpr auto k0 = Number<kmnk[Number<0>{}]>{};
constexpr auto m0 = Number<kmnk[Number<1>{}]>{};
constexpr auto n0 = Number<kmnk[Number<2>{}]>{};
constexpr auto k_inner = Number<kmnk[Number<3>{}]>{};
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k0,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
n0,
k0,
I0,
I0,
I0,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
wmma_gemm.Run(
a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
Number<kk / A_K1>{}, m0, k0, I0, I0, I0, Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
Number<kk / B_K1>{}, n0, k0, I0, I0, I0, Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
block_sync_lds();
@@ -506,52 +491,35 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
b_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KInner, 1>{}([&](auto k_inner) {
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_ford<Sequence<KRepeat, MRepeat, NRepeat, KInner>>{}([&](auto kmnk) {
constexpr auto k0 = Number<kmnk[Number<0>{}]>{};
constexpr auto m0 = Number<kmnk[Number<1>{}]>{};
constexpr auto n0 = Number<kmnk[Number<2>{}]>{};
constexpr auto k_inner = Number<kmnk[Number<3>{}]>{};
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k0,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
n0,
k0,
I0,
I0,
I0,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
Number<kk / A_K1>{}, m0, k0, I0, I0, I0, Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
Number<kk / B_K1>{}, n0, k0, I0, I0, I0, Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a = typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b = typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
block_sync_lds();
@@ -564,52 +532,35 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
// Tail, always perform.
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KInner, 1>{}([&](auto k_inner) {
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_ford<Sequence<KRepeat, MRepeat, NRepeat, KInner>>{}([&](auto kmnk) {
constexpr auto k0 = Number<kmnk[Number<0>{}]>{};
constexpr auto m0 = Number<kmnk[Number<1>{}]>{};
constexpr auto n0 = Number<kmnk[Number<2>{}]>{};
constexpr auto k_inner = Number<kmnk[Number<3>{}]>{};
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k0,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
n0,
k0,
I0,
I0,
I0,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
Number<kk / A_K1>{}, m0, k0, I0, I0, I0, Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
Number<kk / B_K1>{}, n0, k0, I0, I0, I0, Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a = typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b = typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
// Let's leak last WMMA block to epilogue region, cover the potential lds-shuffle
// latency
@@ -747,58 +698,55 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
a_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) {
c_scale_struct.Clear();
static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
static_for<0, KInner, 1>{}([&](auto k_inner) {
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_ford<Sequence<MRepeat, NRepeat, NumScaleKBlock>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto kscale0 = Number<mnk[Number<2>{}]>{};
c_scale_struct.Clear();
static_ford<Sequence<KRepeat / NumScaleKBlock, KInner>>{}([&](auto kk_id) {
constexpr auto k0 = Number<kk_id[Number<0>{}]>{};
constexpr auto k_inner = Number<kk_id[Number<1>{}]>{};
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index =
kscale0 * (KRepeat / NumScaleKBlock) + k0;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k_index,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index =
kscale0 * (KRepeat / NumScaleKBlock) + k0;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
n0,
k_index,
I0,
I0,
I0,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_scale_struct.c_thread_buf_per_scale
.GetVectorTypeReference(Number<0>{}));
});
});
c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index = kscale0 * (KRepeat / NumScaleKBlock) + k0;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k_index,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index = kscale0 * (KRepeat / NumScaleKBlock) + k0;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
n0,
k_index,
I0,
I0,
I0,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(
Number<0>{}));
});
c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
});
c_scale_struct.Load(a_scale_struct, b_scale_struct);
@@ -825,59 +773,55 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
a_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);
b_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) {
c_scale_struct.Clear();
static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
static_for<0, KInner, 1>{}([&](auto k_inner) {
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_ford<Sequence<MRepeat, NRepeat, NumScaleKBlock>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto kscale0 = Number<mnk[Number<2>{}]>{};
c_scale_struct.Clear();
static_ford<Sequence<KRepeat / NumScaleKBlock, KInner>>{}([&](auto kk_id) {
constexpr auto k0 = Number<kk_id[Number<0>{}]>{};
constexpr auto k_inner = Number<kk_id[Number<1>{}]>{};
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index =
kscale0 * (KRepeat / NumScaleKBlock) + k0;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k_index,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index =
kscale0 * (KRepeat / NumScaleKBlock) + k0;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
n0,
k_index,
I0,
I0,
I0,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
wmma_gemm.Run(
a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(
Number<0>{}));
});
});
c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index = kscale0 * (KRepeat / NumScaleKBlock) + k0;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k_index,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index = kscale0 * (KRepeat / NumScaleKBlock) + k0;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
n0,
k_index,
I0,
I0,
I0,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
wmma_gemm.Run(
a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
});
c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
});
c_scale_struct.Load(a_scale_struct, b_scale_struct);
@@ -891,58 +835,54 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
// Tail, always perform.
{
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) {
c_scale_struct.Clear();
static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KInner, 1>{}([&](auto k_inner) {
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index =
kscale0 * (KRepeat / NumScaleKBlock) + k0;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k_index,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index =
kscale0 * (KRepeat / NumScaleKBlock) + k0;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
n0,
k_index,
I0,
I0,
I0,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
wmma_gemm.Run(
a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(
Number<0>{}));
});
});
c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
static_ford<Sequence<MRepeat, NRepeat, NumScaleKBlock>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto kscale0 = Number<mnk[Number<2>{}]>{};
c_scale_struct.Clear();
static_ford<Sequence<KRepeat / NumScaleKBlock, KInner>>{}([&](auto kk_id) {
constexpr auto k0 = Number<kk_id[Number<0>{}]>{};
constexpr auto k_inner = Number<kk_id[Number<1>{}]>{};
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index = kscale0 * (KRepeat / NumScaleKBlock) + k0;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k_index,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index = kscale0 * (KRepeat / NumScaleKBlock) + k0;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
n0,
k_index,
I0,
I0,
I0,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
wmma_gemm.Run(
a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
});
c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
});
// Let's leak last WMMA block to epilogue region, cover the potential lds-shuffle
// latency

View File

@@ -512,22 +512,22 @@ struct BlockwiseGemmXdlops_pipeline_v4
// Local prefetch 1th, Fill Ping Reg
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_bufs(I0));
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(I0),
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_bufs(I0));
});
static_ford<Sequence<KRepeat, MRepeat>>{}([&](auto km) {
constexpr auto k = Number<km[Number<0>{}]>{};
constexpr auto m0 = Number<km[Number<1>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_bufs(I0));
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(I0),
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_bufs(I0));
});
});
@@ -566,22 +566,22 @@ struct BlockwiseGemmXdlops_pipeline_v4
// DS_READ: Pong LDS to Pong Reg
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(PongP1{}),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_bufs(PongP1{}));
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(PongP1{}),
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_bufs(PongP1{}));
});
static_ford<Sequence<KRepeat, MRepeat>>{}([&](auto km) {
constexpr auto k = Number<km[Number<0>{}]>{};
constexpr auto m0 = Number<km[Number<1>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(PongP1{}),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_bufs(PongP1{}));
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(PongP1{}),
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_bufs(PongP1{}));
});
});
@@ -594,33 +594,31 @@ struct BlockwiseGemmXdlops_pipeline_v4
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_ford<Sequence<KRepeat, MRepeat, NRepeat>>{}([&](auto kmn) {
constexpr auto k0 = Number<kmn[Number<0>{}]>{};
constexpr auto m0 = Number<kmn[Number<1>{}]>{};
constexpr auto n0 = Number<kmn[Number<2>{}]>{};
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_bufs[PingP1{}][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_bufs[PingP1{}][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_bufs[PingP1{}][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_bufs[PingP1{}][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
HotLoopScheduler();
@@ -634,22 +632,22 @@ struct BlockwiseGemmXdlops_pipeline_v4
// DS_READ: Ping LDS to Ping Reg
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(PongP2{}),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_bufs(PongP2{}));
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(PongP2{}),
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_bufs(PongP2{}));
});
static_ford<Sequence<KRepeat, MRepeat>>{}([&](auto km) {
constexpr auto k = Number<km[Number<0>{}]>{};
constexpr auto m0 = Number<km[Number<1>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(PongP2{}),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_bufs(PongP2{}));
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(PongP2{}),
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_bufs(PongP2{}));
});
});
@@ -662,33 +660,31 @@ struct BlockwiseGemmXdlops_pipeline_v4
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_ford<Sequence<KRepeat, MRepeat, NRepeat>>{}([&](auto kmn) {
constexpr auto k0 = Number<kmn[Number<0>{}]>{};
constexpr auto m0 = Number<kmn[Number<1>{}]>{};
constexpr auto n0 = Number<kmn[Number<2>{}]>{};
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_bufs[PingP2{}][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_bufs[PingP2{}][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_bufs[PingP2{}][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_bufs[PingP2{}][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
HotLoopScheduler();
@@ -708,54 +704,52 @@ struct BlockwiseGemmXdlops_pipeline_v4
// DS_READ: Pong LDS to Pong Reg
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(PongP1{}),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_bufs(PongP1{}));
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(PongP1{}),
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_bufs(PongP1{}));
});
static_ford<Sequence<KRepeat, MRepeat>>{}([&](auto km) {
constexpr auto k = Number<km[Number<0>{}]>{};
constexpr auto m0 = Number<km[Number<1>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(PongP1{}),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_bufs(PongP1{}));
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(PongP1{}),
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_bufs(PongP1{}));
});
});
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP1{}));
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP1{}));
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_ford<Sequence<KRepeat, MRepeat, NRepeat>>{}([&](auto kmn) {
constexpr auto k0 = Number<kmn[Number<0>{}]>{};
constexpr auto m0 = Number<kmn[Number<1>{}]>{};
constexpr auto n0 = Number<kmn[Number<2>{}]>{};
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_bufs[PingP1{}][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_bufs[PingP1{}][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_bufs[PingP1{}][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_bufs[PingP1{}][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
TailScheduler<1>();
@@ -769,82 +763,78 @@ struct BlockwiseGemmXdlops_pipeline_v4
// DS_READ: Ping LDS to Ping Reg
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(PongP2{}),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_bufs(PongP2{}));
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(PongP2{}),
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_bufs(PongP2{}));
});
static_ford<Sequence<KRepeat, MRepeat>>{}([&](auto km) {
constexpr auto k = Number<km[Number<0>{}]>{};
constexpr auto m0 = Number<km[Number<1>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(PongP2{}),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_bufs(PongP2{}));
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(PongP2{}),
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_bufs(PongP2{}));
});
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_ford<Sequence<KRepeat, MRepeat, NRepeat>>{}([&](auto kmn) {
constexpr auto k0 = Number<kmn[Number<0>{}]>{};
constexpr auto m0 = Number<kmn[Number<1>{}]>{};
constexpr auto n0 = Number<kmn[Number<2>{}]>{};
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_bufs[PingP2{}][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_bufs[PingP2{}][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_bufs[PingP2{}][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_bufs[PingP2{}][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
TailScheduler<2>();
__builtin_amdgcn_sched_barrier(0);
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_ford<Sequence<KRepeat, MRepeat, NRepeat>>{}([&](auto kmn) {
constexpr auto k = Number<kmn[Number<0>{}]>{};
constexpr auto m0 = Number<kmn[Number<1>{}]>{};
constexpr auto n0 = Number<kmn[Number<2>{}]>{};
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_bufs[PongP2{}][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_bufs[PongP2{}][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_bufs[PongP2{}][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_bufs[PongP2{}][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
// 64 v_mfma
@@ -860,51 +850,49 @@ struct BlockwiseGemmXdlops_pipeline_v4
// DS_READ: Pong LDS to Pong Reg
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(PongP1{}),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_bufs(PongP1{}));
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(PongP1{}),
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_bufs(PongP1{}));
});
static_ford<Sequence<KRepeat, MRepeat>>{}([&](auto km) {
constexpr auto k = Number<km[Number<0>{}]>{};
constexpr auto m0 = Number<km[Number<1>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(PongP1{}),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_bufs(PongP1{}));
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(PongP1{}),
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_bufs(PongP1{}));
});
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_ford<Sequence<KRepeat, MRepeat, NRepeat>>{}([&](auto kmn) {
constexpr auto k0 = Number<kmn[Number<0>{}]>{};
constexpr auto m0 = Number<kmn[Number<1>{}]>{};
constexpr auto n0 = Number<kmn[Number<2>{}]>{};
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_bufs[PingP1{}][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_bufs[PingP1{}][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_bufs[PingP1{}][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_bufs[PingP1{}][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
TailScheduler<2>();
@@ -916,32 +904,30 @@ struct BlockwiseGemmXdlops_pipeline_v4
// DS_WRITE: To Pong LDS
// DS_READ: Ping LDS to Ping Reg
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_ford<Sequence<KRepeat, MRepeat, NRepeat>>{}([&](auto kmn) {
constexpr auto k0 = Number<kmn[Number<0>{}]>{};
constexpr auto m0 = Number<kmn[Number<1>{}]>{};
constexpr auto n0 = Number<kmn[Number<2>{}]>{};
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_bufs[PingP2{}][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_bufs[PingP2{}][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_bufs[PingP2{}][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_bufs[PingP2{}][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
// 64 v_mfma

View File

@@ -275,15 +275,15 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v1<BlockGemmPipelineSch
// Local prefetch A1
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf);
});
static_ford<Sequence<MRepeat, KRepeat>>{}([&](auto mk) {
constexpr auto m0 = Number<mk[Number<0>{}]>{};
constexpr auto k0 = Number<mk[Number<1>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf);
});
// B VGPR->VGPR dequant
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
@@ -318,47 +318,43 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v1<BlockGemmPipelineSch
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_ford<Sequence<MRepeat, NRepeat, KRepeat>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto k0 = Number<mnk[Number<2>{}]>{};
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_dequant_bufs[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_dequant_bufs[mfma_reg_buf][Number<
b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf);
});
static_ford<Sequence<MRepeat, KRepeat>>{}([&](auto mk) {
constexpr auto m0 = Number<mk[Number<0>{}]>{};
constexpr auto k0 = Number<mk[Number<1>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf);
});
// B VGPR->VGPR dequant
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
@@ -390,45 +386,42 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v1<BlockGemmPipelineSch
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_ford<Sequence<MRepeat, NRepeat, KRepeat>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto k0 = Number<mnk[Number<2>{}]>{};
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_dequant_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) = b_thread_dequant_bufs
[I0][Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf);
});
static_ford<Sequence<MRepeat, KRepeat>>{}([&](auto mk) {
constexpr auto m0 = Number<mk[Number<0>{}]>{};
constexpr auto k0 = Number<mk[Number<1>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf);
});
// B VGPR->VGPR dequant
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
@@ -440,32 +433,29 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v1<BlockGemmPipelineSch
__builtin_amdgcn_sched_barrier(0);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_ford<Sequence<MRepeat, NRepeat, KRepeat>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto k0 = Number<mnk[Number<2>{}]>{};
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_dequant_bufs[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) = b_thread_dequant_bufs
[I1][Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
// Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle
// latency
@@ -473,32 +463,29 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v1<BlockGemmPipelineSch
}
else
{
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_ford<Sequence<MRepeat, NRepeat, KRepeat>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto k0 = Number<mnk[Number<2>{}]>{};
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_dequant_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) = b_thread_dequant_bufs
[I0][Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
}
}

View File

@@ -593,39 +593,38 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v3<BlockGemmPipelineSch
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
}
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_ford<Sequence<KRepeat, NRepeat>>{}([&](auto kn) {
constexpr auto k0 = Number<kn[Number<0>{}]>{};
constexpr auto n0 = Number<kn[Number<1>{}]>{};
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple((m0 + HotloopLocalBufSwitch * mfma_reg_buf) %
2,
I0,
I0,
k0,
I0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_dequant_bufs[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple((m0 + HotloopLocalBufSwitch * mfma_reg_buf) % 2,
I0,
I0,
k0,
I0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_dequant_bufs[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
if constexpr(m0.value == MRepeat - 1)
@@ -710,30 +709,30 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v3<BlockGemmPipelineSch
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
}
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_ford<Sequence<KRepeat, NRepeat>>{}([&](auto kn) {
constexpr auto k0 = Number<kn[Number<0>{}]>{};
constexpr auto n0 = Number<kn[Number<1>{}]>{};
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_dequant_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_dequant_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
if constexpr(m0.value == MRepeat - 1)
@@ -781,30 +780,30 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v3<BlockGemmPipelineSch
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_ford<Sequence<KRepeat, NRepeat>>{}([&](auto kn) {
constexpr auto k0 = Number<kn[Number<0>{}]>{};
constexpr auto n0 = Number<kn[Number<1>{}]>{};
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
(m0 + HotloopLocalBufSwitch) % 2, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_dequant_bufs[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
(m0 + HotloopLocalBufSwitch) % 2, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_dequant_bufs[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
if constexpr(m0.value != (MRepeat - 1))
@@ -837,30 +836,30 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v3<BlockGemmPipelineSch
else
{
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_ford<Sequence<KRepeat, NRepeat>>{}([&](auto kn) {
constexpr auto k0 = Number<kn[Number<0>{}]>{};
constexpr auto n0 = Number<kn[Number<1>{}]>{};
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_dequant_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_dequant_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
if constexpr(m0.value != (MRepeat - 1))

View File

@@ -289,15 +289,15 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
// Local prefetch A1
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf);
});
static_ford<Sequence<MRepeat, KRepeat>>{}([&](auto mk) {
constexpr auto m0 = Number<mk[Number<0>{}]>{};
constexpr auto k0 = Number<mk[Number<1>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf);
});
// B VGPR->VGPR dequant
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
@@ -345,57 +345,51 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec_up;
static_ford<Sequence<MRepeat, NRepeat, KRepeat>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto k0 = Number<mnk[Number<2>{}]>{};
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec_up;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_dequant_bufs[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
b_thread_dequant_bufs_up
[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec_up.template AsType<mfma_input_type>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_dequant_bufs[mfma_reg_buf][Number<
b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
b_thread_dequant_bufs_up[mfma_reg_buf][Number<
b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec_up.template AsType<mfma_input_type>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
});
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf);
});
static_ford<Sequence<MRepeat, KRepeat>>{}([&](auto mk) {
constexpr auto m0 = Number<mk[Number<0>{}]>{};
constexpr auto k0 = Number<mk[Number<1>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf);
});
// B VGPR->VGPR dequant
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
@@ -439,52 +433,49 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec_up;
static_ford<Sequence<MRepeat, NRepeat, KRepeat>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto k0 = Number<mnk[Number<2>{}]>{};
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec_up;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_dequant_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
b_thread_dequant_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec_up.template AsType<mfma_input_type>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) = b_thread_dequant_bufs
[I0][Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
b_thread_dequant_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec_up.template AsType<mfma_input_type>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
});
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf);
});
static_ford<Sequence<MRepeat, KRepeat>>{}([&](auto mk) {
constexpr auto m0 = Number<mk[Number<0>{}]>{};
constexpr auto k0 = Number<mk[Number<1>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf);
});
// B VGPR->VGPR dequant
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
@@ -502,39 +493,36 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
b_thread_dequant_bufs_up(I1));
__builtin_amdgcn_sched_barrier(0);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec_up;
static_ford<Sequence<MRepeat, NRepeat, KRepeat>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto k0 = Number<mnk[Number<2>{}]>{};
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec_up;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_dequant_bufs[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
b_thread_dequant_bufs_up[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec_up.template AsType<mfma_input_type>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) = b_thread_dequant_bufs
[I1][Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
b_thread_dequant_bufs_up[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec_up.template AsType<mfma_input_type>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
});
// Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle
// latency
@@ -542,39 +530,36 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
}
else
{
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec_up;
static_ford<Sequence<MRepeat, NRepeat, KRepeat>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto k0 = Number<mnk[Number<2>{}]>{};
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec_up;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_dequant_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
b_thread_dequant_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec_up.template AsType<mfma_input_type>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) = b_thread_dequant_bufs
[I0][Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
b_thread_dequant_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec_up.template AsType<mfma_input_type>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
});
}
}

View File

@@ -298,17 +298,16 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<BlockGemmPipelineSch
// Local prefetch A1
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
static_ford<Sequence<MRepeat, KRepeat, KGroup>>{}([&](auto mkg) {
constexpr auto m0 = Number<mkg[Number<0>{}]>{};
constexpr auto k0 = Number<mkg[Number<1>{}]>{};
constexpr auto kg0 = Number<mkg[Number<2>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
// Initialize C
@@ -342,60 +341,53 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<BlockGemmPipelineSch
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec_up;
static_ford<Sequence<MRepeat, NRepeat, KRepeat>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto k0 = Number<mnk[Number<2>{}]>{};
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec_up;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
b_thread_bufs_up[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec_up.template AsType<mfma_input_type>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
b_thread_bufs_up[mfma_reg_buf][Number<
b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec_up.template AsType<mfma_input_type>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
});
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
static_ford<Sequence<MRepeat, KRepeat, KGroup>>{}([&](auto mkg) {
constexpr auto m0 = Number<mkg[Number<0>{}]>{};
constexpr auto k0 = Number<mkg[Number<1>{}]>{};
constexpr auto kg0 = Number<mkg[Number<2>{}]>{};
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
HotLoopScheduler();
@@ -425,93 +417,83 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<BlockGemmPipelineSch
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec_up;
static_ford<Sequence<MRepeat, NRepeat, KRepeat>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto k0 = Number<mnk[Number<2>{}]>{};
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec_up;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec_up.template AsType<mfma_input_type>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) = b_thread_bufs
[I0][Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeDataType>()(ik) = b_thread_bufs_up
[I0][Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec_up.template AsType<mfma_input_type>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
});
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
static_ford<Sequence<MRepeat, KRepeat, KGroup>>{}([&](auto mkg) {
constexpr auto m0 = Number<mkg[Number<0>{}]>{};
constexpr auto k0 = Number<mkg[Number<1>{}]>{};
constexpr auto kg0 = Number<mkg[Number<2>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
__builtin_amdgcn_sched_barrier(0);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec_up;
static_ford<Sequence<MRepeat, NRepeat, KRepeat>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto k0 = Number<mnk[Number<2>{}]>{};
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec_up;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
b_thread_bufs_up[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec_up.template AsType<mfma_input_type>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) = b_thread_bufs
[I1][Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeDataType>()(ik) = b_thread_bufs_up
[I1][Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec_up.template AsType<mfma_input_type>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
});
// Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle
// latency
@@ -519,39 +501,35 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<BlockGemmPipelineSch
}
else if constexpr(TailNum == TailNumber::Odd)
{
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec_up;
static_ford<Sequence<MRepeat, NRepeat, KRepeat>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto k0 = Number<mnk[Number<2>{}]>{};
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec_up;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec_up.template AsType<mfma_input_type>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) = b_thread_bufs
[I0][Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeDataType>()(ik) = b_thread_bufs_up
[I0][Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec_up.template AsType<mfma_input_type>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
});
}
}

View File

@@ -255,54 +255,53 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
constexpr auto buffer_load_issue_point_a = num_mfma_perstage >= 3 ? 1 : 0;
// B global read
static_for<0, buffer_load_b_stages, 1>{}([&](auto i) {
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
static_ford<Sequence<buffer_load_b_stages, num_mfma_perstage>>{}([&](auto ii) {
constexpr auto i = Number<ii[Number<0>{}]>{};
constexpr auto imfma = Number<ii[Number<1>{}]>{};
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(((i < buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_more ==
buffer_load_issue_point_b)) ||
((i >= buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_less ==
buffer_load_issue_point_b)))
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if constexpr(((i < buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_more ==
buffer_load_issue_point_b)) ||
((i >= buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_less ==
buffer_load_issue_point_b)))
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
// A global read + A local write
static_for<0, buffer_load_a_stages, 1>{}([&](auto i) {
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_more ==
ds_write_issue_point)) ||
(((i + buffer_load_b_stages) >= buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_less ==
ds_write_issue_point)))
{
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
}
if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_more ==
buffer_load_issue_point_a)) ||
(((i + buffer_load_b_stages) >= buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_less ==
buffer_load_issue_point_a)))
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
static_ford<Sequence<buffer_load_a_stages, num_mfma_perstage>>{}([&](auto ii) {
constexpr auto i = Number<ii[Number<0>{}]>{};
constexpr auto imfma = Number<ii[Number<1>{}]>{};
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_more ==
ds_write_issue_point)) ||
(((i + buffer_load_b_stages) >= buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_less == ds_write_issue_point)))
{
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
}
if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_more ==
buffer_load_issue_point_a)) ||
(((i + buffer_load_b_stages) >= buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_less ==
buffer_load_issue_point_a)))
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
// lds synchronization, prefetch next loop local A
@@ -511,17 +510,16 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
// Local prefetch A1
block_sync_lds();
static_for<0, 2, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
static_ford<Sequence<2, KRepeat, KGroup>>{}([&](auto mkk) {
constexpr auto m0 = Number<mkk[Number<0>{}]>{};
constexpr auto k0 = Number<mkk[Number<1>{}]>{};
constexpr auto kg0 = Number<mkk[Number<2>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
// Initialize C
@@ -554,130 +552,129 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec_up;
static_ford<Sequence<KRepeat, NRepeat>>{}([&](auto kn) {
constexpr auto k0 = Number<kn[Number<0>{}]>{};
constexpr auto n0 = Number<kn[Number<1>{}]>{};
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec_up;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple((m0 + HotloopLocalBufSwitch * mfma_reg_buf) %
2,
I0,
I0,
k0,
I0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple((m0 + HotloopLocalBufSwitch * mfma_reg_buf) % 2,
I0,
I0,
k0,
I0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
b_thread_bufs_up[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec_up.template AsType<mfma_input_type>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
b_thread_bufs_up[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec_up.template AsType<mfma_input_type>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
});
if constexpr(m0.value == MRepeat - 2)
{
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(local_read_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
static_ford<Sequence<KRepeat, KGroup>>{}([&](auto kk) {
constexpr auto k0 = Number<kk[Number<0>{}]>{};
constexpr auto kg0 = Number<kk[Number<1>{}]>{};
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(local_read_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
}
else if constexpr(m0.value == (MRepeat - 1))
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(local_read_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
static_ford<Sequence<KRepeat, KGroup>>{}([&](auto kk) {
constexpr auto k0 = Number<kk[Number<0>{}]>{};
constexpr auto kg0 = Number<kk[Number<1>{}]>{};
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(local_read_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
}
else
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(mfma_reg_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
static_ford<Sequence<KRepeat, KGroup>>{}([&](auto kk) {
constexpr auto k0 = Number<kk[Number<0>{}]>{};
constexpr auto kg0 = Number<kk[Number<1>{}]>{};
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(mfma_reg_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
}
});
@@ -706,100 +703,100 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
b_thread_bufs_up(I1));
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec_up;
static_ford<Sequence<KRepeat, NRepeat>>{}([&](auto kn) {
constexpr auto k0 = Number<kn[Number<0>{}]>{};
constexpr auto n0 = Number<kn[Number<1>{}]>{};
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec_up;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec_up.template AsType<mfma_input_type>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec_up.template AsType<mfma_input_type>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
});
if constexpr(m0.value == (MRepeat - 2))
{
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
static_ford<Sequence<KRepeat, KGroup>>{}([&](auto kk) {
constexpr auto k0 = Number<kk[Number<0>{}]>{};
constexpr auto kg0 = Number<kk[Number<1>{}]>{};
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
}
else if constexpr(m0.value == MRepeat - 1)
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
static_ford<Sequence<KRepeat, KGroup>>{}([&](auto kk) {
constexpr auto k0 = Number<kk[Number<0>{}]>{};
constexpr auto kg0 = Number<kk[Number<1>{}]>{};
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
}
else
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
static_ford<Sequence<KRepeat, KGroup>>{}([&](auto kk) {
constexpr auto k0 = Number<kk[Number<0>{}]>{};
constexpr auto kg0 = Number<kk[Number<1>{}]>{};
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
}
});
@@ -807,58 +804,58 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
HotLoopScheduler();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec_up;
static_ford<Sequence<KRepeat, NRepeat>>{}([&](auto kn) {
constexpr auto k0 = Number<kn[Number<0>{}]>{};
constexpr auto n0 = Number<kn[Number<1>{}]>{};
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec_up;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
(m0 + HotloopLocalBufSwitch) % 2, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
b_thread_bufs_up[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec_up.template AsType<mfma_input_type>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
(m0 + HotloopLocalBufSwitch) % 2, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
b_thread_bufs_up[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec_up.template AsType<mfma_input_type>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
});
if constexpr(m0.value < (MRepeat - 2))
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(
Number<m0 + 2>{}, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
static_ford<Sequence<KRepeat, KGroup>>{}([&](auto kk) {
constexpr auto k0 = Number<kk[Number<0>{}]>{};
constexpr auto kg0 = Number<kk[Number<1>{}]>{};
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(
Number<m0 + 2>{}, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
}
});
@@ -870,53 +867,53 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
else if constexpr(TailNum == TailNumber::Odd)
{
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec_up;
static_ford<Sequence<KRepeat, NRepeat>>{}([&](auto kn) {
constexpr auto k0 = Number<kn[Number<0>{}]>{};
constexpr auto n0 = Number<kn[Number<1>{}]>{};
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec_up;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec_up.template AsType<mfma_input_type>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec_up.template AsType<mfma_input_type>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
});
if constexpr(m0.value < (MRepeat - 2))
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(
Number<m0 + 2>{}, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
static_ford<Sequence<KRepeat, KGroup>>{}([&](auto kk) {
constexpr auto k0 = Number<kk[Number<0>{}]>{};
constexpr auto kg0 = Number<kk[Number<1>{}]>{};
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(
Number<m0 + 2>{}, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
}
});

View File

@@ -259,25 +259,26 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3<
// Stage 1
// global read more
static_for<0, buffer_load_stages_more, 1>{}([&](auto /*i*/) {
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
static_ford<Sequence<buffer_load_stages_more, num_mfma_perstage>>{}([&](auto ii) {
constexpr auto imfma = Number<ii[Number<1>{}]>{};
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(imfma % buffer_load_issue_point_interval_more == 0)
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if constexpr(imfma % buffer_load_issue_point_interval_more == 0)
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
// global read less
static_for<0, (num_total_stages - 2 - buffer_load_stages_more), 1>{}([&](auto /*i*/) {
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
static_ford<
Sequence<(num_total_stages - 2 - buffer_load_stages_more), num_mfma_perstage>>{}(
[&](auto ii) {
constexpr auto imfma = Number<ii[Number<1>{}]>{};
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(imfma % buffer_load_issue_point_interval_less == 0)
{
@@ -288,22 +289,20 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3<
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
});
// Stage 2, Sync
// lds synchronization, prefetch next loop local A
static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto /*i*/) {
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0)
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
static_ford<Sequence<num_ds_read_a_prefetch_stages, num_mfma_perstage>>{}([&](auto ii) {
constexpr auto imfma = Number<ii[Number<1>{}]>{};
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0)
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
}
@@ -463,25 +462,24 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3<
// Local prefetch 1, sync the async load
__builtin_amdgcn_s_waitcnt(async_vmcnt_encoding);
block_sync_lds();
static_for<0, LocalPrefetchStages, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
(APackedSize * KPack / xdlops_gemm.K1PerXdlops);
static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
[&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_m3_k,
make_tuple(
I0, I0, Number<m0 % MXdlPack>{}, I0, Number<a_k_step_chunk>{}),
a_block_bufs(I0),
a_thread_desc_,
make_tuple(
I0, I0, Number<m0 % MXdlPack>{}, k, Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
});
static_ford<Sequence<LocalPrefetchStages, KRepeat>>{}([&](auto mk) {
constexpr auto m0 = Number<mk[Number<0>{}]>{};
constexpr auto k = Number<mk[Number<1>{}]>{};
constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
(APackedSize * KPack / xdlops_gemm.K1PerXdlops);
static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
[&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_m3_k,
make_tuple(I0, I0, Number<m0 % MXdlPack>{}, I0, Number<a_k_step_chunk>{}),
a_block_bufs(I0),
a_thread_desc_,
make_tuple(
I0, I0, Number<m0 % MXdlPack>{}, k, Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
});
// Global prefetch 2
@@ -583,105 +581,97 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3<
static_for<0, MRepeat, 1>{}([&](auto m0) {
constexpr auto im_major = m0 / MXdlPack;
constexpr auto im_minor = m0 % MXdlPack;
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_ford<Sequence<KRepeat, NRepeat>>{}([&](auto kn) {
constexpr auto k0 = Number<kn[Number<0>{}]>{};
constexpr auto n0 = Number<kn[Number<1>{}]>{};
constexpr auto ik_major = k0 / KXdlPack;
constexpr auto ik_minor = k0 % KXdlPack;
static_for<0, NRepeat, 1>{}([&](auto n0) {
constexpr auto in_major = n0 / NXdlPack;
constexpr auto in_minor = n0 % NXdlPack;
constexpr auto in_major = n0 / NXdlPack;
constexpr auto in_minor = n0 % NXdlPack;
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(
make_tuple(im_major, ik_major, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(
make_tuple(in_major, ik_major, I0));
constexpr index_t a_scale_offset = a_scale_thread_desc.CalculateOffset(
make_tuple(im_major, ik_major, I0));
constexpr index_t b_scale_offset = b_scale_thread_desc.CalculateOffset(
make_tuple(in_major, ik_major, I0));
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
vector_type<AScaleDataType, a_scale_thread_vec_size>
a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size>
b_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size>
b_scale_thread_vec_up;
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size>
b_scale_thread_vec_up;
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(
scale_comp_buf)[Number<a_scale_offset + s>{}];
});
// B Gate scale
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(
scale_comp_buf)[Number<b_scale_offset + s>{}];
});
// B Up scale
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs_up(
scale_comp_buf)[Number<b_scale_offset + s>{}];
});
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(
scale_comp_buf)[Number<a_scale_offset + s>{}];
});
// B Gate scale
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(
scale_comp_buf)[Number<b_scale_offset + s>{}];
});
// B Up scale
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs_up(
scale_comp_buf)[Number<b_scale_offset + s>{}];
});
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec_up;
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec_up;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(I0, I0, im_minor, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) = b_thread_bufs
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(I0, I0, im_minor, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[scale_comp_buf]
[Number<b_thread_desc_.CalculateOffset(make_tuple(
in_major, I0, in_minor, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs_up
[scale_comp_buf][Number<b_thread_desc_.CalculateOffset(
make_tuple(in_major, I0, in_minor, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs_up
[scale_comp_buf][Number<b_thread_desc_.CalculateOffset(
make_tuple(in_major, I0, in_minor, k0, ik))>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops /
APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops /
BPackedSize>::type;
using mfma_scale_input_type_a =
typename vector_type<AScaleDataType,
a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b =
typename vector_type<BScaleDataType,
b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(im_major, in_major, im_minor, in_minor, 0));
// MFMA accumulation A * Gate
xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
ik_minor * NXdlPack + in_minor>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
// MFMA accumulation A * Up
xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
ik_minor * NXdlPack + in_minor>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec_up.template AsType<mfma_input_type_b>(),
b_scale_thread_vec_up
.template AsType<mfma_scale_input_type_b>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
using mfma_scale_input_type_a =
typename vector_type<AScaleDataType, a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b =
typename vector_type<BScaleDataType, b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(im_major, in_major, im_minor, in_minor, 0));
// MFMA accumulation A * Gate
xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
ik_minor * NXdlPack + in_minor>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
// MFMA accumulation A * Up
xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
ik_minor * NXdlPack + in_minor>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec_up.template AsType<mfma_input_type_b>(),
b_scale_thread_vec_up.template AsType<mfma_scale_input_type_b>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
});
if constexpr(m0.value == SwitchM)
@@ -798,91 +788,91 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3<
static_for<0, MRepeat, 1>{}([&](auto m0) {
constexpr auto im_major = m0 / MXdlPack;
constexpr auto im_minor = m0 % MXdlPack;
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_ford<Sequence<KRepeat, NRepeat>>{}([&](auto kn) {
constexpr auto k0 = Number<kn[Number<0>{}]>{};
constexpr auto n0 = Number<kn[Number<1>{}]>{};
constexpr auto ik_major = k0 / KXdlPack;
constexpr auto ik_minor = k0 % KXdlPack;
static_for<0, NRepeat, 1>{}([&](auto n0) {
constexpr auto in_major = n0 / NXdlPack;
constexpr auto in_minor = n0 % NXdlPack;
constexpr auto in_major = n0 / NXdlPack;
constexpr auto in_minor = n0 % NXdlPack;
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0));
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0));
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec_up;
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec_up;
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
});
// B Gate scale
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
});
// B Up scale
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs_up(I0)[Number<b_scale_offset + s>{}];
});
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec_up;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(I0, I0, im_minor, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(in_major, I0, in_minor, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(in_major, I0, in_minor, k0, ik))>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
using mfma_scale_input_type_a =
typename vector_type<AScaleDataType, a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b =
typename vector_type<BScaleDataType, b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(im_major, in_major, im_minor, in_minor, 0));
// MFMA accumulation A * Gate
xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
ik_minor * NXdlPack + in_minor>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
// MFMA accumulation A * Gate
xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
ik_minor * NXdlPack + in_minor>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec_up.template AsType<mfma_input_type_b>(),
b_scale_thread_vec_up.template AsType<mfma_scale_input_type_b>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
});
// B Gate scale
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
});
// B Up scale
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs_up(I0)[Number<b_scale_offset + s>{}];
});
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec_up;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(I0, I0, im_minor, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(in_major, I0, in_minor, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(in_major, I0, in_minor, k0, ik))>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
using mfma_scale_input_type_a =
typename vector_type<AScaleDataType, a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b =
typename vector_type<BScaleDataType, b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(im_major, in_major, im_minor, in_minor, 0));
// MFMA accumulation A * Gate
xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
ik_minor * NXdlPack + in_minor>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
// MFMA accumulation A * Gate
xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
ik_minor * NXdlPack + in_minor>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec_up.template AsType<mfma_input_type_b>(),
b_scale_thread_vec_up.template AsType<mfma_scale_input_type_b>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
});
if constexpr(m0.value == SwitchM)
{
@@ -920,89 +910,89 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3<
static_for<0, MRepeat, 1>{}([&](auto m0) {
constexpr auto im_major = m0 / MXdlPack;
constexpr auto im_minor = m0 % MXdlPack;
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_ford<Sequence<KRepeat, NRepeat>>{}([&](auto kn) {
constexpr auto k0 = Number<kn[Number<0>{}]>{};
constexpr auto n0 = Number<kn[Number<1>{}]>{};
constexpr auto ik_major = k0 / KXdlPack;
constexpr auto ik_minor = k0 % KXdlPack;
static_for<0, NRepeat, 1>{}([&](auto n0) {
constexpr auto in_major = n0 / NXdlPack;
constexpr auto in_minor = n0 % NXdlPack;
constexpr auto in_major = n0 / NXdlPack;
constexpr auto in_minor = n0 % NXdlPack;
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0));
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0));
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec_up;
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec_up;
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(I1)[Number<a_scale_offset + s>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(I1)[Number<b_scale_offset + s>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs_up(I1)[Number<b_scale_offset + s>{}];
});
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec_up;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(I0, I0, im_minor, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(in_major, I0, in_minor, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs_up[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(in_major, I0, in_minor, k0, ik))>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
using mfma_scale_input_type_a =
typename vector_type<AScaleDataType, a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b =
typename vector_type<BScaleDataType, b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(im_major, in_major, im_minor, in_minor, 0));
// MFMA accumulation A * Gate
xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
ik_minor * NXdlPack + in_minor>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
// MFMA accumulation A * Up
xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
ik_minor * NXdlPack + in_minor>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec_up.template AsType<mfma_input_type_b>(),
b_scale_thread_vec_up.template AsType<mfma_scale_input_type_b>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(I1)[Number<a_scale_offset + s>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(I1)[Number<b_scale_offset + s>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs_up(I1)[Number<b_scale_offset + s>{}];
});
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec_up;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(I0, I0, im_minor, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(in_major, I0, in_minor, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs_up[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(in_major, I0, in_minor, k0, ik))>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
using mfma_scale_input_type_a =
typename vector_type<AScaleDataType, a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b =
typename vector_type<BScaleDataType, b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(im_major, in_major, im_minor, in_minor, 0));
// MFMA accumulation A * Gate
xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
ik_minor * NXdlPack + in_minor>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
// MFMA accumulation A * Up
xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
ik_minor * NXdlPack + in_minor>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec_up.template AsType<mfma_input_type_b>(),
b_scale_thread_vec_up.template AsType<mfma_scale_input_type_b>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
});
if constexpr(m0.value < (MRepeat - LocalPrefetchStages))
{
@@ -1040,91 +1030,91 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3<
static_for<0, MRepeat, 1>{}([&](auto m0) {
constexpr auto im_major = m0 / MXdlPack;
constexpr auto im_minor = m0 % MXdlPack;
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_ford<Sequence<KRepeat, NRepeat>>{}([&](auto kn) {
constexpr auto k0 = Number<kn[Number<0>{}]>{};
constexpr auto n0 = Number<kn[Number<1>{}]>{};
constexpr auto ik_major = k0 / KXdlPack;
constexpr auto ik_minor = k0 % KXdlPack;
static_for<0, NRepeat, 1>{}([&](auto n0) {
constexpr auto in_major = n0 / NXdlPack;
constexpr auto in_minor = n0 % NXdlPack;
constexpr auto in_major = n0 / NXdlPack;
constexpr auto in_minor = n0 % NXdlPack;
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0));
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0));
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec_up;
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec_up;
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
});
// B Gate scale
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
});
// B Up scale
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs_up(I0)[Number<b_scale_offset + s>{}];
});
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec_up;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(I0, I0, im_minor, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(in_major, I0, in_minor, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(in_major, I0, in_minor, k0, ik))>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
using mfma_scale_input_type_a =
typename vector_type<AScaleDataType, a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b =
typename vector_type<BScaleDataType, b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(im_major, in_major, im_minor, in_minor, 0));
// MFMA accumulation A * Gate
xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
ik_minor * NXdlPack + in_minor>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
// MFMA accumulation A * up
xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
ik_minor * NXdlPack + in_minor>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec_up.template AsType<mfma_input_type_b>(),
b_scale_thread_vec_up.template AsType<mfma_scale_input_type_b>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
});
// B Gate scale
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
});
// B Up scale
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs_up(I0)[Number<b_scale_offset + s>{}];
});
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec_up;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(I0, I0, im_minor, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(in_major, I0, in_minor, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(in_major, I0, in_minor, k0, ik))>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
using mfma_scale_input_type_a =
typename vector_type<AScaleDataType, a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b =
typename vector_type<BScaleDataType, b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(im_major, in_major, im_minor, in_minor, 0));
// MFMA accumulation A * Gate
xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
ik_minor * NXdlPack + in_minor>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
// MFMA accumulation A * up
xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
ik_minor * NXdlPack + in_minor>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec_up.template AsType<mfma_input_type_b>(),
b_scale_thread_vec_up.template AsType<mfma_scale_input_type_b>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
});
if constexpr(m0.value < (MRepeat - LocalPrefetchStages))
{

View File

@@ -357,31 +357,31 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1<BlockGemmPipelineSched
// Local prefetch 1, sync the async load
__builtin_amdgcn_s_waitcnt(async_vmcnt_encoding);
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k) {
static_ford<
Sequence<MRepeat, KRepeat, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk)>>{}(
[&](auto mkc) {
constexpr auto m0 = Number<mkc[Number<0>{}]>{};
constexpr auto k = Number<mkc[Number<1>{}]>{};
constexpr auto chunk = Number<mkc[Number<2>{}]>{};
constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
(APackedSize * KPack / xdlops_gemm.K1PerXdlops);
static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
[&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
I0,
Number<a_k_step_chunk>{}),
a_block_buf,
a_thread_desc_,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
I0,
Number<a_k_step_chunk>{}),
a_block_buf,
a_thread_desc_,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
});
// Initialize C
c_thread_buf.Clear();
@@ -448,118 +448,107 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1<BlockGemmPipelineSched
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_ford<Sequence<MRepeat, KRepeat, NRepeat>>{}([&](auto mkn) {
constexpr auto m0 = Number<mkn[Number<0>{}]>{};
constexpr auto k0 = Number<mkn[Number<1>{}]>{};
constexpr auto n0 = Number<mkn[Number<2>{}]>{};
constexpr auto im_major = m0 / MXdlPack;
constexpr auto im_minor = m0 % MXdlPack;
static_for<0, KRepeat, 1>{}([&](auto k0) {
constexpr auto ik_major = k0 / KXdlPack;
constexpr auto ik_minor = k0 % KXdlPack;
static_for<0, NRepeat, 1>{}([&](auto n0) {
constexpr auto in_major = n0 / NXdlPack;
constexpr auto in_minor = n0 % NXdlPack;
constexpr auto ik_major = k0 / KXdlPack;
constexpr auto ik_minor = k0 % KXdlPack;
constexpr auto in_major = n0 / NXdlPack;
constexpr auto in_minor = n0 % NXdlPack;
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(
make_tuple(im_major, ik_major, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(
make_tuple(in_major, ik_major, I0));
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0));
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
vector_type<AScaleDataType, a_scale_thread_vec_size>
a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size>
b_scale_thread_vec;
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(
scale_comp_buf)[Number<a_scale_offset + s>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(
scale_comp_buf)[Number<b_scale_offset + s>{}];
});
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(im_major, I0, im_minor, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) = b_thread_bufs
[scale_comp_buf][Number<b_thread_desc_.CalculateOffset(
make_tuple(in_major, I0, in_minor, k0, ik))>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops /
APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops /
BPackedSize>::type;
using mfma_scale_input_type_a =
typename vector_type<AScaleDataType,
a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b =
typename vector_type<BScaleDataType,
b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(im_major, in_major, im_minor, in_minor, 0));
// MFMA accumulation
xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
ik_minor * NXdlPack + in_minor>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(scale_comp_buf)[Number<a_scale_offset + s>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(scale_comp_buf)[Number<b_scale_offset + s>{}];
});
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(im_major, I0, im_minor, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[scale_comp_buf][Number<b_thread_desc_.CalculateOffset(
make_tuple(in_major, I0, in_minor, k0, ik))>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
using mfma_scale_input_type_a =
typename vector_type<AScaleDataType, a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b =
typename vector_type<BScaleDataType, b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(im_major, in_major, im_minor, in_minor, 0));
// MFMA accumulation
xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
ik_minor * NXdlPack + in_minor>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k) {
static_ford<Sequence<MRepeat,
KRepeat,
xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk)>>{}(
[&](auto mkc) {
constexpr auto m0 = Number<mkc[Number<0>{}]>{};
constexpr auto k = Number<mkc[Number<1>{}]>{};
constexpr auto chunk = Number<mkc[Number<2>{}]>{};
constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
(APackedSize * KPack / xdlops_gemm.K1PerXdlops);
static_for<0,
xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk),
1>{}([&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step +
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
I0,
Number<a_k_step_chunk>{}),
a_block_buf,
a_thread_desc_,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
constexpr auto a_k_step_chunk =
k_step +
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
I0,
Number<a_k_step_chunk>{}),
a_block_buf,
a_thread_desc_,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
});
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
};
@@ -611,257 +600,246 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1<BlockGemmPipelineSched
b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_ford<Sequence<MRepeat, KRepeat, NRepeat>>{}([&](auto mkn) {
constexpr auto m0 = Number<mkn[Number<0>{}]>{};
constexpr auto k0 = Number<mkn[Number<1>{}]>{};
constexpr auto n0 = Number<mkn[Number<2>{}]>{};
constexpr auto im_major = m0 / MXdlPack;
constexpr auto im_minor = m0 % MXdlPack;
static_for<0, KRepeat, 1>{}([&](auto k0) {
constexpr auto ik_major = k0 / KXdlPack;
constexpr auto ik_minor = k0 % KXdlPack;
static_for<0, NRepeat, 1>{}([&](auto n0) {
constexpr auto in_major = n0 / NXdlPack;
constexpr auto in_minor = n0 % NXdlPack;
constexpr auto ik_major = k0 / KXdlPack;
constexpr auto ik_minor = k0 % KXdlPack;
constexpr auto in_major = n0 / NXdlPack;
constexpr auto in_minor = n0 % NXdlPack;
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0));
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0));
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
});
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(im_major, I0, im_minor, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(in_major, I0, in_minor, k0, ik))>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
using mfma_scale_input_type_a =
typename vector_type<AScaleDataType, a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b =
typename vector_type<BScaleDataType, b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(im_major, in_major, im_minor, in_minor, 0));
// MFMA accumulation
xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
ik_minor * NXdlPack + in_minor>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
});
// constexpr auto lds_buf = m0.value >= SwitchM ? I1 : I0;
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
});
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(im_major, I0, im_minor, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(in_major, I0, in_minor, k0, ik))>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops / BPackedSize>::type;
using mfma_scale_input_type_a =
typename vector_type<AScaleDataType, a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b =
typename vector_type<BScaleDataType, b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(im_major, in_major, im_minor, in_minor, 0));
// MFMA accumulation
xdlops_gemm
.template Run<ik_minor * MXdlPack + im_minor, ik_minor * NXdlPack + in_minor>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
__builtin_amdgcn_s_waitcnt(async_vmcnt_encoding);
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k) {
static_ford<Sequence<MRepeat,
KRepeat,
xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk)>>{}(
[&](auto mkc) {
constexpr auto m0 = Number<mkc[Number<0>{}]>{};
constexpr auto k = Number<mkc[Number<1>{}]>{};
constexpr auto chunk = Number<mkc[Number<2>{}]>{};
constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
(APackedSize * KPack / xdlops_gemm.K1PerXdlops);
static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
[&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step +
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
I0,
Number<a_k_step_chunk>{}),
a_block_buf,
a_thread_desc_,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
I0,
Number<a_k_step_chunk>{}),
a_block_buf,
a_thread_desc_,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
});
__builtin_amdgcn_sched_barrier(0);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_ford<Sequence<MRepeat, KRepeat, NRepeat>>{}([&](auto mkn) {
constexpr auto m0 = Number<mkn[Number<0>{}]>{};
constexpr auto k0 = Number<mkn[Number<1>{}]>{};
constexpr auto n0 = Number<mkn[Number<2>{}]>{};
constexpr auto im_major = m0 / MXdlPack;
constexpr auto im_minor = m0 % MXdlPack;
static_for<0, KRepeat, 1>{}([&](auto k0) {
constexpr auto ik_major = k0 / KXdlPack;
constexpr auto ik_minor = k0 % KXdlPack;
static_for<0, NRepeat, 1>{}([&](auto n0) {
constexpr auto in_major = n0 / NXdlPack;
constexpr auto in_minor = n0 % NXdlPack;
constexpr auto ik_major = k0 / KXdlPack;
constexpr auto ik_minor = k0 % KXdlPack;
constexpr auto in_major = n0 / NXdlPack;
constexpr auto in_minor = n0 % NXdlPack;
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0));
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0));
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(I1)[Number<a_scale_offset + s>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(I1)[Number<b_scale_offset + s>{}];
});
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(im_major, I0, im_minor, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(in_major, I0, in_minor, k0, ik))>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
using mfma_scale_input_type_a =
typename vector_type<AScaleDataType, a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b =
typename vector_type<BScaleDataType, b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(im_major, in_major, im_minor, in_minor, 0));
// MFMA accumulation
xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
ik_minor * NXdlPack + in_minor>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(I1)[Number<a_scale_offset + s>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(I1)[Number<b_scale_offset + s>{}];
});
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(im_major, I0, im_minor, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(in_major, I0, in_minor, k0, ik))>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops / BPackedSize>::type;
using mfma_scale_input_type_a =
typename vector_type<AScaleDataType, a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b =
typename vector_type<BScaleDataType, b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(im_major, in_major, im_minor, in_minor, 0));
// MFMA accumulation
xdlops_gemm
.template Run<ik_minor * MXdlPack + im_minor, ik_minor * NXdlPack + in_minor>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
}
else if constexpr(TailNum == TailNumber::Odd)
{
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_ford<Sequence<MRepeat, KRepeat, NRepeat>>{}([&](auto mkn) {
constexpr auto m0 = Number<mkn[Number<0>{}]>{};
constexpr auto k0 = Number<mkn[Number<1>{}]>{};
constexpr auto n0 = Number<mkn[Number<2>{}]>{};
constexpr auto im_major = m0 / MXdlPack;
constexpr auto im_minor = m0 % MXdlPack;
static_for<0, KRepeat, 1>{}([&](auto k0) {
constexpr auto ik_major = k0 / KXdlPack;
constexpr auto ik_minor = k0 % KXdlPack;
static_for<0, NRepeat, 1>{}([&](auto n0) {
constexpr auto in_major = n0 / NXdlPack;
constexpr auto in_minor = n0 % NXdlPack;
constexpr auto ik_major = k0 / KXdlPack;
constexpr auto ik_minor = k0 % KXdlPack;
constexpr auto in_major = n0 / NXdlPack;
constexpr auto in_minor = n0 % NXdlPack;
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0));
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0));
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
});
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(im_major, I0, im_minor, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(in_major, I0, in_minor, k0, ik))>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
using mfma_scale_input_type_a =
typename vector_type<AScaleDataType, a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b =
typename vector_type<BScaleDataType, b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(im_major, in_major, im_minor, in_minor, 0));
// MFMA accumulation
xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
ik_minor * NXdlPack + in_minor>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
});
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(im_major, I0, im_minor, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(in_major, I0, in_minor, k0, ik))>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops / BPackedSize>::type;
using mfma_scale_input_type_a =
typename vector_type<AScaleDataType, a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b =
typename vector_type<BScaleDataType, b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(im_major, in_major, im_minor, in_minor, 0));
// MFMA accumulation
xdlops_gemm
.template Run<ik_minor * MXdlPack + im_minor, ik_minor * NXdlPack + in_minor>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
}
}

View File

@@ -261,54 +261,49 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3<BlockGemmPipelineSched
// Stage 1
// global read more
static_for<0, buffer_load_stages_more, 1>{}([&](auto /*i*/) {
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
static_ford<Sequence<buffer_load_stages_more, num_mfma_perstage>>{}([&](auto ii) {
constexpr auto imfma = Number<ii[Number<1>{}]>{};
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(imfma % buffer_load_issue_point_interval_more == 0)
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if constexpr(imfma % buffer_load_issue_point_interval_more == 0)
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(
0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
// global read less
static_for<0, (num_total_stages - 2 - buffer_load_stages_more), 1>{}([&](auto /*i*/) {
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(imfma % buffer_load_issue_point_interval_less == 0)
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(
0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
static_ford<Sequence<(num_total_stages - 2 - buffer_load_stages_more),
num_mfma_perstage>>{}([&](auto ii) {
constexpr auto imfma = Number<ii[Number<1>{}]>{};
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(imfma % buffer_load_issue_point_interval_less == 0)
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
// Stage 2, Sync
// lds synchronization, prefetch next loop local A
static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto /*i*/) {
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0)
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(
0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
static_ford<Sequence<num_ds_read_a_prefetch_stages, num_mfma_perstage>>{}([&](auto ii) {
constexpr auto imfma = Number<ii[Number<1>{}]>{};
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0)
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
}
else
@@ -536,25 +531,24 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3<BlockGemmPipelineSched
// Local prefetch 1, sync the async load
__builtin_amdgcn_s_waitcnt(async_vmcnt_encoding);
block_sync_lds();
static_for<0, LocalPrefetchStages, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
(APackedSize * KPack / xdlops_gemm.K1PerXdlops);
static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
[&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_m3_k,
make_tuple(
I0, I0, Number<m0 % MXdlPack>{}, I0, Number<a_k_step_chunk>{}),
a_block_bufs(I0),
a_thread_desc_,
make_tuple(
I0, I0, Number<m0 % MXdlPack>{}, k, Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
});
static_ford<Sequence<LocalPrefetchStages, KRepeat>>{}([&](auto mk) {
constexpr auto m0 = Number<mk[Number<0>{}]>{};
constexpr auto k = Number<mk[Number<1>{}]>{};
constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
(APackedSize * KPack / xdlops_gemm.K1PerXdlops);
static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
[&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_m3_k,
make_tuple(I0, I0, Number<m0 % MXdlPack>{}, I0, Number<a_k_step_chunk>{}),
a_block_bufs(I0),
a_thread_desc_,
make_tuple(
I0, I0, Number<m0 % MXdlPack>{}, k, Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
});
// Global prefetch 2
@@ -628,83 +622,76 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3<BlockGemmPipelineSched
static_for<0, MRepeat, 1>{}([&](auto m0) {
constexpr auto im_major = m0 / MXdlPack;
constexpr auto im_minor = m0 % MXdlPack;
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_ford<Sequence<KRepeat, NRepeat>>{}([&](auto kn) {
constexpr auto k0 = Number<kn[Number<0>{}]>{};
constexpr auto n0 = Number<kn[Number<1>{}]>{};
constexpr auto ik_major = k0 / KXdlPack;
constexpr auto ik_minor = k0 % KXdlPack;
static_for<0, NRepeat, 1>{}([&](auto n0) {
constexpr auto in_major = n0 / NXdlPack;
constexpr auto in_minor = n0 % NXdlPack;
constexpr auto in_major = n0 / NXdlPack;
constexpr auto in_minor = n0 % NXdlPack;
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(
make_tuple(im_major, ik_major, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(
make_tuple(in_major, ik_major, I0));
constexpr index_t a_scale_offset = a_scale_thread_desc.CalculateOffset(
make_tuple(im_major, ik_major, I0));
constexpr index_t b_scale_offset = b_scale_thread_desc.CalculateOffset(
make_tuple(in_major, ik_major, I0));
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
vector_type<AScaleDataType, a_scale_thread_vec_size>
a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size>
b_scale_thread_vec;
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(
scale_comp_buf)[Number<a_scale_offset + s>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(
scale_comp_buf)[Number<b_scale_offset + s>{}];
});
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(I0, I0, im_minor, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) = b_thread_bufs
[scale_comp_buf][Number<b_thread_desc_.CalculateOffset(
make_tuple(in_major, I0, in_minor, k0, ik))>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops /
APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops /
BPackedSize>::type;
using mfma_scale_input_type_a =
typename vector_type<AScaleDataType,
a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b =
typename vector_type<BScaleDataType,
b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(im_major, in_major, im_minor, in_minor, 0));
// MFMA accumulation
xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
ik_minor * NXdlPack + in_minor>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(
scale_comp_buf)[Number<a_scale_offset + s>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(
scale_comp_buf)[Number<b_scale_offset + s>{}];
});
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(I0, I0, im_minor, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[scale_comp_buf]
[Number<b_thread_desc_.CalculateOffset(make_tuple(
in_major, I0, in_minor, k0, ik))>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
using mfma_scale_input_type_a =
typename vector_type<AScaleDataType, a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b =
typename vector_type<BScaleDataType, b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(im_major, in_major, im_minor, in_minor, 0));
// MFMA accumulation
xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
ik_minor * NXdlPack + in_minor>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
if constexpr(m0.value == SwitchM)
@@ -802,73 +789,73 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3<BlockGemmPipelineSched
static_for<0, MRepeat, 1>{}([&](auto m0) {
constexpr auto im_major = m0 / MXdlPack;
constexpr auto im_minor = m0 % MXdlPack;
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_ford<Sequence<KRepeat, NRepeat>>{}([&](auto kn) {
constexpr auto k0 = Number<kn[Number<0>{}]>{};
constexpr auto n0 = Number<kn[Number<1>{}]>{};
constexpr auto ik_major = k0 / KXdlPack;
constexpr auto ik_minor = k0 % KXdlPack;
static_for<0, NRepeat, 1>{}([&](auto n0) {
constexpr auto in_major = n0 / NXdlPack;
constexpr auto in_minor = n0 % NXdlPack;
constexpr auto in_major = n0 / NXdlPack;
constexpr auto in_minor = n0 % NXdlPack;
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0));
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0));
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
});
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(I0, I0, im_minor, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(in_major, I0, in_minor, k0, ik))>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
using mfma_scale_input_type_a =
typename vector_type<AScaleDataType, a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b =
typename vector_type<BScaleDataType, b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(im_major, in_major, im_minor, in_minor, 0));
// MFMA accumulation
xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
ik_minor * NXdlPack + in_minor>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
});
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(I0, I0, im_minor, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(in_major, I0, in_minor, k0, ik))>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
using mfma_scale_input_type_a =
typename vector_type<AScaleDataType, a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b =
typename vector_type<BScaleDataType, b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(im_major, in_major, im_minor, in_minor, 0));
// MFMA accumulation
xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
ik_minor * NXdlPack + in_minor>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
if constexpr(m0.value == SwitchM)
{
@@ -906,73 +893,73 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3<BlockGemmPipelineSched
static_for<0, MRepeat, 1>{}([&](auto m0) {
constexpr auto im_major = m0 / MXdlPack;
constexpr auto im_minor = m0 % MXdlPack;
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_ford<Sequence<KRepeat, NRepeat>>{}([&](auto kn) {
constexpr auto k0 = Number<kn[Number<0>{}]>{};
constexpr auto n0 = Number<kn[Number<1>{}]>{};
constexpr auto ik_major = k0 / KXdlPack;
constexpr auto ik_minor = k0 % KXdlPack;
static_for<0, NRepeat, 1>{}([&](auto n0) {
constexpr auto in_major = n0 / NXdlPack;
constexpr auto in_minor = n0 % NXdlPack;
constexpr auto in_major = n0 / NXdlPack;
constexpr auto in_minor = n0 % NXdlPack;
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0));
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0));
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(I1)[Number<a_scale_offset + s>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(I1)[Number<b_scale_offset + s>{}];
});
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(I0, I0, im_minor, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(in_major, I0, in_minor, k0, ik))>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
using mfma_scale_input_type_a =
typename vector_type<AScaleDataType, a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b =
typename vector_type<BScaleDataType, b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(im_major, in_major, im_minor, in_minor, 0));
// MFMA accumulation
xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
ik_minor * NXdlPack + in_minor>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(I1)[Number<a_scale_offset + s>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(I1)[Number<b_scale_offset + s>{}];
});
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(I0, I0, im_minor, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(in_major, I0, in_minor, k0, ik))>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
using mfma_scale_input_type_a =
typename vector_type<AScaleDataType, a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b =
typename vector_type<BScaleDataType, b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(im_major, in_major, im_minor, in_minor, 0));
// MFMA accumulation
xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
ik_minor * NXdlPack + in_minor>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
if constexpr(m0.value < (MRepeat - LocalPrefetchStages))
{
@@ -1010,73 +997,73 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3<BlockGemmPipelineSched
static_for<0, MRepeat, 1>{}([&](auto m0) {
constexpr auto im_major = m0 / MXdlPack;
constexpr auto im_minor = m0 % MXdlPack;
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_ford<Sequence<KRepeat, NRepeat>>{}([&](auto kn) {
constexpr auto k0 = Number<kn[Number<0>{}]>{};
constexpr auto n0 = Number<kn[Number<1>{}]>{};
constexpr auto ik_major = k0 / KXdlPack;
constexpr auto ik_minor = k0 % KXdlPack;
static_for<0, NRepeat, 1>{}([&](auto n0) {
constexpr auto in_major = n0 / NXdlPack;
constexpr auto in_minor = n0 % NXdlPack;
constexpr auto in_major = n0 / NXdlPack;
constexpr auto in_minor = n0 % NXdlPack;
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0));
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0));
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
});
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(I0, I0, im_minor, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(in_major, I0, in_minor, k0, ik))>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
using mfma_scale_input_type_a =
typename vector_type<AScaleDataType, a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b =
typename vector_type<BScaleDataType, b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(im_major, in_major, im_minor, in_minor, 0));
// MFMA accumulation
xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
ik_minor * NXdlPack + in_minor>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
});
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(I0, I0, im_minor, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(in_major, I0, in_minor, k0, ik))>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
using mfma_scale_input_type_a =
typename vector_type<AScaleDataType, a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b =
typename vector_type<BScaleDataType, b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(im_major, in_major, im_minor, in_minor, 0));
// MFMA accumulation
xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
ik_minor * NXdlPack + in_minor>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
if constexpr(m0.value < (MRepeat - LocalPrefetchStages))
{

View File

@@ -280,17 +280,16 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
// Local prefetch A1
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
static_ford<Sequence<MRepeat, KRepeat, KGroup>>{}([&](auto mkg) {
constexpr auto m0 = Number<mkg[Number<0>{}]>{};
constexpr auto k0 = Number<mkg[Number<1>{}]>{};
constexpr auto kg0 = Number<mkg[Number<2>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
// Initialize C
@@ -318,51 +317,46 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_ford<Sequence<MRepeat, NRepeat, KRepeat>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto k0 = Number<mnk[Number<2>{}]>{};
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
block_sync_lds();
// loop prefetch copy
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
static_ford<Sequence<MRepeat, KRepeat, KGroup>>{}([&](auto mkg) {
constexpr auto m0 = Number<mkg[Number<0>{}]>{};
constexpr auto k0 = Number<mkg[Number<1>{}]>{};
constexpr auto kg0 = Number<mkg[Number<2>{}]>{};
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
HotLoopScheduler();
@@ -387,79 +381,71 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_ford<Sequence<MRepeat, NRepeat, KRepeat>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto k0 = Number<mnk[Number<2>{}]>{};
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) = b_thread_bufs
[I0][Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
block_sync_lds();
// tail Local Prefetch A1
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
static_ford<Sequence<MRepeat, KRepeat, KGroup>>{}([&](auto mkg) {
constexpr auto m0 = Number<mkg[Number<0>{}]>{};
constexpr auto k0 = Number<mkg[Number<1>{}]>{};
constexpr auto kg0 = Number<mkg[Number<2>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
__builtin_amdgcn_sched_barrier(0);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_ford<Sequence<MRepeat, NRepeat, KRepeat>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto k0 = Number<mnk[Number<2>{}]>{};
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) = b_thread_bufs
[I1][Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
// Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle
// latency
@@ -467,32 +453,29 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
}
else if constexpr(TailNum == TailNumber::Odd)
{
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_ford<Sequence<MRepeat, NRepeat, KRepeat>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto k0 = Number<mnk[Number<2>{}]>{};
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) = b_thread_bufs
[I0][Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
}
}

View File

@@ -281,17 +281,16 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
// Local prefetch A1
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_bufs(I0));
});
});
static_ford<Sequence<MRepeat, KRepeat, KGroup>>{}([&](auto mkg) {
constexpr auto m0 = Number<mkg[Number<0>{}]>{};
constexpr auto k0 = Number<mkg[Number<1>{}]>{};
constexpr auto kg0 = Number<mkg[Number<2>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_bufs(I0));
});
// Local prefill A2
@@ -323,18 +322,17 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// main loop A matrix prefetch
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(local_read_buf),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_bufs(local_read_buf));
});
});
static_ford<Sequence<MRepeat, KRepeat, KGroup>>{}([&](auto mkg) {
constexpr auto m0 = Number<mkg[Number<0>{}]>{};
constexpr auto k0 = Number<mkg[Number<1>{}]>{};
constexpr auto kg0 = Number<mkg[Number<2>{}]>{};
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(local_read_buf),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_bufs(local_read_buf));
});
a_blockwise_copy.RunWrite(
@@ -343,36 +341,31 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_ford<Sequence<MRepeat, NRepeat, KRepeat>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto k0 = Number<mnk[Number<2>{}]>{};
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_bufs[mfma_reg_buf]
[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
HotLoopScheduler();
@@ -398,48 +391,44 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// tail prefetch A
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(local_read_reg),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_bufs(local_read_reg));
});
});
static_ford<Sequence<MRepeat, KRepeat, KGroup>>{}([&](auto mkg) {
constexpr auto m0 = Number<mkg[Number<0>{}]>{};
constexpr auto k0 = Number<mkg[Number<1>{}]>{};
constexpr auto kg0 = Number<mkg[Number<2>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(local_read_reg),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_bufs(local_read_reg));
});
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(mfma_reg), mfma_reg);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_ford<Sequence<MRepeat, NRepeat, KRepeat>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto k0 = Number<mnk[Number<2>{}]>{};
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_bufs[mfma_reg][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[mfma_reg][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_bufs[mfma_reg][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[mfma_reg][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
HotLoopScheduler();
@@ -455,46 +444,42 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
b_block_origin_idx,
b_thread_bufs(local_read_reg));
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(local_read_reg),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_bufs(local_read_reg));
});
});
static_ford<Sequence<MRepeat, KRepeat, KGroup>>{}([&](auto mkg) {
constexpr auto m0 = Number<mkg[Number<0>{}]>{};
constexpr auto k0 = Number<mkg[Number<1>{}]>{};
constexpr auto kg0 = Number<mkg[Number<2>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(local_read_reg),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_bufs(local_read_reg));
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_ford<Sequence<MRepeat, NRepeat, KRepeat>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto k0 = Number<mnk[Number<2>{}]>{};
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_bufs[mfma_reg][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[mfma_reg][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_bufs[mfma_reg][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[mfma_reg][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
HotLoopScheduler();
@@ -502,32 +487,30 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlockGemmPipelineScheduler::I
};
auto CompFunc = [&](auto mfma_reg) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_ford<Sequence<MRepeat, NRepeat, KRepeat>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto k0 = Number<mnk[Number<2>{}]>{};
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_bufs[mfma_reg][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[mfma_reg][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_bufs[mfma_reg][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[mfma_reg][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
};

View File

@@ -258,52 +258,50 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
constexpr auto buffer_load_issue_point_a = num_mfma_perstage >= 3 ? 1 : 0;
// B global read
static_for<0, buffer_load_b_stages, 1>{}([&](auto i) {
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(SCHED_GROUP_MFMA, 1, 0);
static_ford<Sequence<buffer_load_b_stages, num_mfma_perstage>>{}([&](auto ii) {
constexpr auto i = Number<ii[Number<0>{}]>{};
constexpr auto imfma = Number<ii[Number<1>{}]>{};
__builtin_amdgcn_sched_group_barrier(SCHED_GROUP_MFMA, 1, 0);
if constexpr(((i < buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_more == 0)) ||
((i >= buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_less == 0)))
{
__builtin_amdgcn_sched_group_barrier(SCHED_GROUP_VMEM, 1, 0);
}
if constexpr(((i < buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_more == 0)) ||
((i >= buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_less == 0)))
{
__builtin_amdgcn_sched_group_barrier(SCHED_GROUP_VMEM, 1, 0);
}
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(
SCHED_GROUP_LDS_READ, ds_read_a_mfma_rate, 0);
}
});
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(SCHED_GROUP_LDS_READ, ds_read_a_mfma_rate, 0);
}
});
// A global read + A local write
static_for<0, buffer_load_a_stages, 1>{}([&](auto i) {
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(SCHED_GROUP_MFMA, 1, 0);
if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_more == 0)) ||
(((i + buffer_load_b_stages) >= buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_less == 0)))
{
__builtin_amdgcn_sched_group_barrier(SCHED_GROUP_LDS_WRITE, 1, 0);
}
if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_more ==
buffer_load_issue_point_a)) ||
(((i + buffer_load_b_stages) >= buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_less ==
buffer_load_issue_point_a)))
{
__builtin_amdgcn_sched_group_barrier(SCHED_GROUP_VMEM, 1, 0);
}
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(
SCHED_GROUP_LDS_READ, ds_read_a_mfma_rate, 0);
}
});
static_ford<Sequence<buffer_load_a_stages, num_mfma_perstage>>{}([&](auto ii) {
constexpr auto i = Number<ii[Number<0>{}]>{};
constexpr auto imfma = Number<ii[Number<1>{}]>{};
__builtin_amdgcn_sched_group_barrier(SCHED_GROUP_MFMA, 1, 0);
if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_more == 0)) ||
(((i + buffer_load_b_stages) >= buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_less == 0)))
{
__builtin_amdgcn_sched_group_barrier(SCHED_GROUP_LDS_WRITE, 1, 0);
}
if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_more ==
buffer_load_issue_point_a)) ||
(((i + buffer_load_b_stages) >= buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_less ==
buffer_load_issue_point_a)))
{
__builtin_amdgcn_sched_group_barrier(SCHED_GROUP_VMEM, 1, 0);
}
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(SCHED_GROUP_LDS_READ, ds_read_a_mfma_rate, 0);
}
});
// lds synchronization, prefetch next loop local A
@@ -379,17 +377,16 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
// Local prefetch A1
block_sync_lds();
static_for<0, DS_READ_A_PREFETCH_STAGES, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
static_ford<Sequence<DS_READ_A_PREFETCH_STAGES, KRepeat, KGroup>>{}([&](auto mkk) {
constexpr auto m0 = Number<mkk[Number<0>{}]>{};
constexpr auto k0 = Number<mkk[Number<1>{}]>{};
constexpr auto kg0 = Number<mkk[Number<2>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
// Initialize C
@@ -416,119 +413,114 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_ford<Sequence<KRepeat, NRepeat>>{}([&](auto kn) {
constexpr auto k0 = Number<kn[Number<0>{}]>{};
constexpr auto n0 = Number<kn[Number<1>{}]>{};
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple((m0 + HotloopLocalBufSwitch * mfma_reg_buf) %
2,
I0,
I0,
k0,
I0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple((m0 + HotloopLocalBufSwitch * mfma_reg_buf) % 2,
I0,
I0,
k0,
I0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
if constexpr(m0.value == (MRepeat - 2))
{
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<0>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(local_read_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
static_ford<Sequence<KRepeat, KGroup>>{}([&](auto kk) {
constexpr auto k0 = Number<kk[Number<0>{}]>{};
constexpr auto kg0 = Number<kk[Number<1>{}]>{};
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(
Number<0>{}, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(local_read_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
}
else if constexpr(m0.value == (MRepeat - 1))
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(local_read_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
static_ford<Sequence<KRepeat, KGroup>>{}([&](auto kk) {
constexpr auto k0 = Number<kk[Number<0>{}]>{};
constexpr auto kg0 = Number<kk[Number<1>{}]>{};
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(local_read_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
}
else
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(mfma_reg_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
static_ford<Sequence<KRepeat, KGroup>>{}([&](auto kk) {
constexpr auto k0 = Number<kk[Number<0>{}]>{};
constexpr auto kg0 = Number<kk[Number<1>{}]>{};
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(mfma_reg_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
}
});
@@ -552,88 +544,87 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_ford<Sequence<KRepeat, NRepeat>>{}([&](auto kn) {
constexpr auto k0 = Number<kn[Number<0>{}]>{};
constexpr auto n0 = Number<kn[Number<1>{}]>{};
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
if constexpr(m0.value == (MRepeat - 2))
{
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(
Number<0>{}, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
static_ford<Sequence<KRepeat, KGroup>>{}([&](auto kk) {
constexpr auto k0 = Number<kk[Number<0>{}]>{};
constexpr auto kg0 = Number<kk[Number<1>{}]>{};
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<0>{}, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
}
else if constexpr(m0.value == (MRepeat - 1))
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
static_ford<Sequence<KRepeat, KGroup>>{}([&](auto kk) {
constexpr auto k0 = Number<kk[Number<0>{}]>{};
constexpr auto kg0 = Number<kk[Number<1>{}]>{};
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
}
else
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
static_ford<Sequence<KRepeat, KGroup>>{}([&](auto kk) {
constexpr auto k0 = Number<kk[Number<0>{}]>{};
constexpr auto kg0 = Number<kk[Number<1>{}]>{};
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
}
});
@@ -641,50 +632,50 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
HotLoopScheduler();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_ford<Sequence<KRepeat, NRepeat>>{}([&](auto kn) {
constexpr auto k0 = Number<kn[Number<0>{}]>{};
constexpr auto n0 = Number<kn[Number<1>{}]>{};
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
(m0 + HotloopLocalBufSwitch) % 2, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
(m0 + HotloopLocalBufSwitch) % 2, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
if constexpr(m0.value < (MRepeat - 2))
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(
Number<m0 + 2>{}, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
static_ford<Sequence<KRepeat, KGroup>>{}([&](auto kk) {
constexpr auto k0 = Number<kk[Number<0>{}]>{};
constexpr auto kg0 = Number<kk[Number<1>{}]>{};
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(
Number<m0 + 2>{}, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
}
});
@@ -694,46 +685,46 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
else if constexpr(TailNum == TailNumber::Odd)
{
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_ford<Sequence<KRepeat, NRepeat>>{}([&](auto kn) {
constexpr auto k0 = Number<kn[Number<0>{}]>{};
constexpr auto n0 = Number<kn[Number<1>{}]>{};
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
if constexpr(m0.value < (MRepeat - 2))
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(
Number<m0 + 2>{}, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
static_ford<Sequence<KRepeat, KGroup>>{}([&](auto kk) {
constexpr auto k0 = Number<kk[Number<0>{}]>{};
constexpr auto kg0 = Number<kk[Number<1>{}]>{};
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(
Number<m0 + 2>{}, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
}
});

View File

@@ -347,22 +347,19 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v1<BlockGemmPipelineS
constexpr auto num_scale_m_block = CScaleThreadDesc{}.GetLength(Number<1>{});
constexpr auto num_scale_n_block = CScaleThreadDesc{}.GetLength(Number<2>{});
static_for<0, num_scale_m_block, 1>{}([&](auto m0) {
static_for<0, num_scale_n_block, 1>{}([&](auto n0) {
static_for<0, num_scale_k_block, 1>{}([&](auto k0) {
constexpr index_t c_offset =
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
constexpr index_t a_offset =
AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
constexpr index_t b_offset =
BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
static_ford<Sequence<num_scale_m_block, num_scale_n_block, num_scale_k_block>>{}(
[&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto k0 = Number<mnk[Number<2>{}]>{};
constexpr index_t c_offset =
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
constexpr index_t a_offset = AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
constexpr index_t b_offset = BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
c_scale_thread_buf(Number<c_offset>{}) =
a_scale_thread_buf[Number<a_offset>{}] *
b_scale_thread_buf[Number<b_offset>{}];
});
c_scale_thread_buf(Number<c_offset>{}) =
a_scale_thread_buf[Number<a_offset>{}] * b_scale_thread_buf[Number<b_offset>{}];
});
});
// Local prefill A1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
@@ -409,18 +406,16 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v1<BlockGemmPipelineS
// Local prefetch A1
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * KPack / KGroup>{}),
a_thread_buf);
});
});
static_ford<Sequence<MRepeat, KRepeat, KGroup>>{}([&](auto mkk) {
constexpr auto m0 = Number<mkk[Number<0>{}]>{};
constexpr auto k0 = Number<mkk[Number<1>{}]>{};
constexpr auto kg0 = Number<mkk[Number<2>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * KPack / KGroup>{}),
a_thread_buf);
});
// Initialize C
@@ -448,114 +443,104 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v1<BlockGemmPipelineS
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
vector_type<AccDataType, 2> c_scale_thread_vec;
constexpr index_t cscale_offset =
CScaleThreadDesc{}.CalculateOffset(
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
static_ford<Sequence<MRepeat, NRepeat, num_scale_k_block>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto kscale0 = Number<mnk[Number<2>{}]>{};
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
vector_type<AccDataType, 2> c_scale_thread_vec;
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0,
I0,
I0,
kscale0 * KRepeat / num_scale_k_block +
k0,
I0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[mfma_reg_buf][Number<
b_thread_desc_.CalculateOffset(make_tuple(
n0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
});
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}(
[&](auto t) {
using pk_fma_type =
typename vector_type<AccDataType, 2>::type;
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) =
__builtin_elementwise_fma(
c_thread_buf_per_scale
.GetVectorTypeReference(Number<0>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec
.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf
.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0,
I0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
I0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(make_tuple(
n0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
});
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
});
});
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * KPack / KGroup>{}),
a_thread_buf);
});
});
static_ford<Sequence<MRepeat, KRepeat, KGroup>>{}([&](auto mkk) {
constexpr auto m0 = Number<mkk[Number<0>{}]>{};
constexpr auto k0 = Number<mkk[Number<1>{}]>{};
constexpr auto kg0 = Number<mkk[Number<2>{}]>{};
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * KPack / KGroup>{}),
a_thread_buf);
});
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, num_scale_n_block, 1>{}([&](auto n0) {
static_for<0, num_scale_k_block, 1>{}([&](auto k0) {
constexpr index_t c_offset =
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
constexpr index_t a_offset =
AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
constexpr index_t b_offset =
BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
static_ford<Sequence<MRepeat, num_scale_n_block, num_scale_k_block>>{}(
[&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto k0 = Number<mnk[Number<2>{}]>{};
constexpr index_t c_offset =
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
constexpr index_t a_offset =
AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
constexpr index_t b_offset =
BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
c_scale_thread_buf(Number<c_offset>{}) =
a_scale_thread_buf[Number<a_offset>{}] *
b_scale_thread_buf[Number<b_offset>{}];
});
c_scale_thread_buf(Number<c_offset>{}) =
a_scale_thread_buf[Number<a_offset>{}] *
b_scale_thread_buf[Number<b_offset>{}];
});
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_scale_thread_copy.Run(a_scale_grid_desc,
@@ -606,231 +591,207 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v1<BlockGemmPipelineS
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
vector_type<AccDataType, 2> c_scale_thread_vec;
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
static_ford<Sequence<MRepeat, NRepeat, num_scale_k_block>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto kscale0 = Number<mnk[Number<2>{}]>{};
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
vector_type<AccDataType, 2> c_scale_thread_vec;
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0,
I0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
I0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
});
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0,
I0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
I0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(make_tuple(
n0, I0, kscale0 * KRepeat / num_scale_k_block + k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
});
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
});
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, num_scale_n_block, 1>{}([&](auto n0) {
static_for<0, num_scale_k_block, 1>{}([&](auto k0) {
constexpr index_t c_offset =
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
constexpr index_t a_offset =
AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
constexpr index_t b_offset =
BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
static_ford<Sequence<MRepeat, num_scale_n_block, num_scale_k_block>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto k0 = Number<mnk[Number<2>{}]>{};
constexpr index_t c_offset =
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
constexpr index_t a_offset = AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
constexpr index_t b_offset = BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
c_scale_thread_buf(Number<c_offset>{}) =
a_scale_thread_buf[Number<a_offset>{}] *
b_scale_thread_buf[Number<b_offset>{}];
});
});
c_scale_thread_buf(Number<c_offset>{}) =
a_scale_thread_buf[Number<a_offset>{}] * b_scale_thread_buf[Number<b_offset>{}];
});
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * KPack / KGroup>{}),
a_thread_buf);
});
});
static_ford<Sequence<MRepeat, KRepeat, KGroup>>{}([&](auto mkk) {
constexpr auto m0 = Number<mkk[Number<0>{}]>{};
constexpr auto k0 = Number<mkk[Number<1>{}]>{};
constexpr auto kg0 = Number<mkk[Number<2>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * KPack / KGroup>{}),
a_thread_buf);
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
vector_type<AccDataType, 2> c_scale_thread_vec;
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
static_ford<Sequence<MRepeat, NRepeat, num_scale_k_block>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto kscale0 = Number<mnk[Number<2>{}]>{};
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
vector_type<AccDataType, 2> c_scale_thread_vec;
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0,
I0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
I0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
});
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0,
I0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
I0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(make_tuple(
n0, I0, kscale0 * KRepeat / num_scale_k_block + k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
});
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
});
});
}
else if constexpr(TailNum == TailNumber::Odd)
{
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
vector_type<AccDataType, 2> c_scale_thread_vec;
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
static_ford<Sequence<MRepeat, NRepeat, num_scale_k_block>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto kscale0 = Number<mnk[Number<2>{}]>{};
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
vector_type<AccDataType, 2> c_scale_thread_vec;
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0,
I0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
I0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
});
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0,
I0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
I0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(make_tuple(
n0, I0, kscale0 * KRepeat / num_scale_k_block + k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
});
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
});
});
}

View File

@@ -538,18 +538,16 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
// Local prefetch A1
block_sync_lds();
static_for<0, LocalPrefetchStages, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * KPack / KGroup>{}),
a_thread_buf);
});
});
static_ford<Sequence<LocalPrefetchStages, KRepeat, KGroup>>{}([&](auto mkk) {
constexpr auto m0 = Number<mkk[Number<0>{}]>{};
constexpr auto k0 = Number<mkk[Number<1>{}]>{};
constexpr auto kg0 = Number<mkk[Number<2>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * KPack / KGroup>{}),
a_thread_buf);
});
#if 1
@@ -717,28 +715,28 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
? local_read_buf
: mfma_reg_buf;
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(Number<lds_buf>{}),
a_thread_desc_,
make_tuple(Number<(m0 + LocalPrefetchStages +
HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
Number<kg0 * KPack / KGroup>{}),
a_thread_buf);
});
static_ford<Sequence<KRepeat, KGroup>>{}([&](auto kk) {
constexpr auto k0 = Number<kk[Number<0>{}]>{};
constexpr auto kg0 = Number<kk[Number<1>{}]>{};
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(Number<lds_buf>{}),
a_thread_desc_,
make_tuple(Number<(m0 + LocalPrefetchStages +
HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
Number<kg0 * KPack / KGroup>{}),
a_thread_buf);
});
});
@@ -841,26 +839,25 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
constexpr auto lds_buf = m0.value >= (MRepeat - LocalPrefetchStages) ? I1 : I0;
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + LocalPrefetchStages) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(Number<lds_buf>{}),
a_thread_desc_,
make_tuple(Number<(m0 + LocalPrefetchStages) % 2>{},
I0,
I0,
k0,
I0,
Number<kg0 * KPack / KGroup>{}),
a_thread_buf);
});
static_ford<Sequence<KRepeat, KGroup>>{}([&](auto kk) {
constexpr auto k0 = Number<kk[Number<0>{}]>{};
constexpr auto kg0 = Number<kk[Number<1>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + LocalPrefetchStages) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(Number<lds_buf>{}),
a_thread_desc_,
make_tuple(Number<(m0 + LocalPrefetchStages) % 2>{},
I0,
I0,
k0,
I0,
Number<kg0 * KPack / KGroup>{}),
a_thread_buf);
});
});
@@ -943,28 +940,27 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
if constexpr(m0.value < (MRepeat - LocalPrefetchStages))
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<m0 + LocalPrefetchStages>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(
Number<(m0 + LocalPrefetchStages + HotloopLocalBufSwitch) %
2>{},
I0,
I0,
k0,
I0,
Number<kg0 * KPack / KGroup>{}),
a_thread_buf);
});
static_ford<Sequence<KRepeat, KGroup>>{}([&](auto kk) {
constexpr auto k0 = Number<kk[Number<0>{}]>{};
constexpr auto kg0 = Number<kk[Number<1>{}]>{};
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<m0 + LocalPrefetchStages>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(
Number<(m0 + LocalPrefetchStages + HotloopLocalBufSwitch) % 2>{},
I0,
I0,
k0,
I0,
Number<kg0 * KPack / KGroup>{}),
a_thread_buf);
});
}
});
@@ -1042,22 +1038,22 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
if constexpr(m0.value < (MRepeat - LocalPrefetchStages))
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(
Number<m0 + 2>{}, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(Number<(m0 + LocalPrefetchStages) % 2>{},
I0,
I0,
k0,
I0,
Number<kg0 * KPack / KGroup>{}),
a_thread_buf);
});
static_ford<Sequence<KRepeat, KGroup>>{}([&](auto kk) {
constexpr auto k0 = Number<kk[Number<0>{}]>{};
constexpr auto kg0 = Number<kk[Number<1>{}]>{};
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(
Number<m0 + 2>{}, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(Number<(m0 + LocalPrefetchStages) % 2>{},
I0,
I0,
k0,
I0,
Number<kg0 * KPack / KGroup>{}),
a_thread_buf);
});
}
});

View File

@@ -375,25 +375,22 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v1<
constexpr auto num_scale_k_block = CScaleThreadDesc{}.GetLength(Number<0>{});
constexpr auto num_scale_m_block = CScaleThreadDesc{}.GetLength(Number<1>{});
constexpr auto num_scale_n_block = CScaleThreadDesc{}.GetLength(Number<2>{});
static_for<0, num_scale_m_block, 1>{}([&](auto m0) {
static_for<0, num_scale_n_block, 1>{}([&](auto n0) {
static_for<0, num_scale_k_block, 1>{}([&](auto k0) {
constexpr index_t c_offset =
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
constexpr index_t a_offset =
AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
constexpr index_t b_offset =
BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
static_ford<Sequence<num_scale_m_block, num_scale_n_block, num_scale_k_block>>{}(
[&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto k0 = Number<mnk[Number<2>{}]>{};
constexpr index_t c_offset =
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
constexpr index_t a_offset = AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
constexpr index_t b_offset = BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
c_scale_thread_buf(Number<c_offset>{}) =
a_scale_thread_buf[Number<a_offset>{}] *
b_scale_thread_buf[Number<b_offset>{}];
c_scale_thread_buf_up(Number<c_offset>{}) =
a_scale_thread_buf[Number<a_offset>{}] *
b_scale_thread_buf_up[Number<b_offset>{}];
});
c_scale_thread_buf(Number<c_offset>{}) =
a_scale_thread_buf[Number<a_offset>{}] * b_scale_thread_buf[Number<b_offset>{}];
c_scale_thread_buf_up(Number<c_offset>{}) =
a_scale_thread_buf[Number<a_offset>{}] *
b_scale_thread_buf_up[Number<b_offset>{}];
});
});
// Local prefill A1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
@@ -450,18 +447,16 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v1<
// Local prefetch A1
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * KPack / KGroup>{}),
a_thread_buf);
});
});
static_ford<Sequence<MRepeat, KRepeat, KGroup>>{}([&](auto mkk) {
constexpr auto m0 = Number<mkk[Number<0>{}]>{};
constexpr auto k0 = Number<mkk[Number<1>{}]>{};
constexpr auto kg0 = Number<mkk[Number<2>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * KPack / KGroup>{}),
a_thread_buf);
});
// Initialize C
@@ -496,149 +491,134 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v1<
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
vector_type<AccDataType, 2> c_scale_thread_vec;
vector_type<AccDataType, 2> c_scale_thread_vec_up;
constexpr index_t cscale_offset =
CScaleThreadDesc{}.CalculateOffset(
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
static_ford<Sequence<MRepeat, NRepeat, num_scale_k_block>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto kscale0 = Number<mnk[Number<2>{}]>{};
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
vector_type<AccDataType, 2> c_scale_thread_vec;
vector_type<AccDataType, 2> c_scale_thread_vec_up;
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec_up.template AsType<AccDataType>()(Number<0>{}) =
c_scale_thread_buf_up[Number<cscale_offset>{}];
c_scale_thread_vec_up.template AsType<AccDataType>()(Number<1>{}) =
c_scale_thread_buf_up[Number<cscale_offset>{}];
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec_up.template AsType<AccDataType>()(Number<0>{}) =
c_scale_thread_buf_up[Number<cscale_offset>{}];
c_scale_thread_vec_up.template AsType<AccDataType>()(Number<1>{}) =
c_scale_thread_buf_up[Number<cscale_offset>{}];
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec_up;
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec_up;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0,
I0,
I0,
kscale0 * KRepeat / num_scale_k_block +
k0,
I0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[mfma_reg_buf][Number<
b_thread_desc_.CalculateOffset(make_tuple(
n0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
ik))>{}];
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
b_thread_bufs_up[mfma_reg_buf][Number<
b_thread_desc_.CalculateOffset(make_tuple(
n0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec_up.template AsType<mfma_input_type>(),
c_thread_buf_per_scale_up.GetVectorTypeReference(
Number<0>{}));
});
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}(
[&](auto t) {
using pk_fma_type =
typename vector_type<AccDataType, 2>::type;
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) =
__builtin_elementwise_fma(
c_thread_buf_per_scale
.GetVectorTypeReference(Number<0>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec
.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf
.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) =
__builtin_elementwise_fma(
c_thread_buf_per_scale_up
.GetVectorTypeReference(Number<0>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec_up
.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf_up
.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0,
I0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
I0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(make_tuple(
n0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
ik))>{}];
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
b_thread_bufs_up
[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec_up.template AsType<mfma_input_type>(),
c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}));
});
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec_up.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
});
});
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * KPack / KGroup>{}),
a_thread_buf);
});
});
static_ford<Sequence<MRepeat, KRepeat, KGroup>>{}([&](auto mkk) {
constexpr auto m0 = Number<mkk[Number<0>{}]>{};
constexpr auto k0 = Number<mkk[Number<1>{}]>{};
constexpr auto kg0 = Number<mkk[Number<2>{}]>{};
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * KPack / KGroup>{}),
a_thread_buf);
});
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, num_scale_n_block, 1>{}([&](auto n0) {
static_for<0, num_scale_k_block, 1>{}([&](auto k0) {
constexpr index_t c_offset =
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
constexpr index_t a_offset =
AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
constexpr index_t b_offset =
BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
static_ford<Sequence<MRepeat, num_scale_n_block, num_scale_k_block>>{}(
[&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto k0 = Number<mnk[Number<2>{}]>{};
constexpr index_t c_offset =
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
constexpr index_t a_offset =
AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
constexpr index_t b_offset =
BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
c_scale_thread_buf(Number<c_offset>{}) =
a_scale_thread_buf[Number<a_offset>{}] *
b_scale_thread_buf[Number<b_offset>{}];
c_scale_thread_buf_up(Number<c_offset>{}) =
a_scale_thread_buf[Number<a_offset>{}] *
b_scale_thread_buf_up[Number<b_offset>{}];
});
c_scale_thread_buf(Number<c_offset>{}) =
a_scale_thread_buf[Number<a_offset>{}] *
b_scale_thread_buf[Number<b_offset>{}];
c_scale_thread_buf_up(Number<c_offset>{}) =
a_scale_thread_buf[Number<a_offset>{}] *
b_scale_thread_buf_up[Number<b_offset>{}];
});
});
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
@@ -699,310 +679,277 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v1<
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
vector_type<AccDataType, 2> c_scale_thread_vec;
vector_type<AccDataType, 2> c_scale_thread_vec_up;
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
static_ford<Sequence<MRepeat, NRepeat, num_scale_k_block>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto kscale0 = Number<mnk[Number<2>{}]>{};
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
vector_type<AccDataType, 2> c_scale_thread_vec;
vector_type<AccDataType, 2> c_scale_thread_vec_up;
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec_up.template AsType<AccDataType>()(Number<0>{}) =
c_scale_thread_buf_up[Number<cscale_offset>{}];
c_scale_thread_vec_up.template AsType<AccDataType>()(Number<1>{}) =
c_scale_thread_buf_up[Number<cscale_offset>{}];
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec_up.template AsType<AccDataType>()(Number<0>{}) =
c_scale_thread_buf_up[Number<cscale_offset>{}];
c_scale_thread_vec_up.template AsType<AccDataType>()(Number<1>{}) =
c_scale_thread_buf_up[Number<cscale_offset>{}];
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec_up;
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec_up;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0,
I0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
I0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
ik))>{}];
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec_up.template AsType<mfma_input_type>(),
c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}));
});
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec_up.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0,
I0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
I0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(make_tuple(
n0, I0, kscale0 * KRepeat / num_scale_k_block + k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(make_tuple(
n0, I0, kscale0 * KRepeat / num_scale_k_block + k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec_up.template AsType<mfma_input_type>(),
c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}));
});
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec_up.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
});
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, num_scale_n_block, 1>{}([&](auto n0) {
static_for<0, num_scale_k_block, 1>{}([&](auto k0) {
constexpr index_t c_offset =
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
constexpr index_t a_offset =
AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
constexpr index_t b_offset =
BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
static_ford<Sequence<MRepeat, num_scale_n_block, num_scale_k_block>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto k0 = Number<mnk[Number<2>{}]>{};
constexpr index_t c_offset =
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
constexpr index_t a_offset = AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
constexpr index_t b_offset = BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
c_scale_thread_buf(Number<c_offset>{}) =
a_scale_thread_buf[Number<a_offset>{}] *
b_scale_thread_buf[Number<b_offset>{}];
c_scale_thread_buf_up(Number<c_offset>{}) =
a_scale_thread_buf[Number<a_offset>{}] *
b_scale_thread_buf_up[Number<b_offset>{}];
});
});
c_scale_thread_buf(Number<c_offset>{}) =
a_scale_thread_buf[Number<a_offset>{}] * b_scale_thread_buf[Number<b_offset>{}];
c_scale_thread_buf_up(Number<c_offset>{}) =
a_scale_thread_buf[Number<a_offset>{}] *
b_scale_thread_buf_up[Number<b_offset>{}];
});
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * KPack / KGroup>{}),
a_thread_buf);
});
});
static_ford<Sequence<MRepeat, KRepeat, KGroup>>{}([&](auto mkk) {
constexpr auto m0 = Number<mkk[Number<0>{}]>{};
constexpr auto k0 = Number<mkk[Number<1>{}]>{};
constexpr auto kg0 = Number<mkk[Number<2>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * KPack / KGroup>{}),
a_thread_buf);
});
// __builtin_amdgcn_sched_barrier(0);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
vector_type<AccDataType, 2> c_scale_thread_vec;
vector_type<AccDataType, 2> c_scale_thread_vec_up;
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
static_ford<Sequence<MRepeat, NRepeat, num_scale_k_block>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto kscale0 = Number<mnk[Number<2>{}]>{};
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
vector_type<AccDataType, 2> c_scale_thread_vec;
vector_type<AccDataType, 2> c_scale_thread_vec_up;
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec_up.template AsType<AccDataType>()(Number<0>{}) =
c_scale_thread_buf_up[Number<cscale_offset>{}];
c_scale_thread_vec_up.template AsType<AccDataType>()(Number<1>{}) =
c_scale_thread_buf_up[Number<cscale_offset>{}];
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec_up.template AsType<AccDataType>()(Number<0>{}) =
c_scale_thread_buf_up[Number<cscale_offset>{}];
c_scale_thread_vec_up.template AsType<AccDataType>()(Number<1>{}) =
c_scale_thread_buf_up[Number<cscale_offset>{}];
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec_up;
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec_up;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0,
I0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
I0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
ik))>{}];
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
b_thread_bufs_up[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec_up.template AsType<mfma_input_type>(),
c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}));
});
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec_up.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0,
I0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
I0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(make_tuple(
n0, I0, kscale0 * KRepeat / num_scale_k_block + k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
b_thread_bufs_up[I1][Number<b_thread_desc_.CalculateOffset(make_tuple(
n0, I0, kscale0 * KRepeat / num_scale_k_block + k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec_up.template AsType<mfma_input_type>(),
c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}));
});
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec_up.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
});
});
}
else if constexpr(TailNum == TailNumber::Odd)
{
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
vector_type<AccDataType, 2> c_scale_thread_vec;
vector_type<AccDataType, 2> c_scale_thread_vec_up;
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
static_ford<Sequence<MRepeat, NRepeat, num_scale_k_block>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto kscale0 = Number<mnk[Number<2>{}]>{};
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
vector_type<AccDataType, 2> c_scale_thread_vec;
vector_type<AccDataType, 2> c_scale_thread_vec_up;
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec_up.template AsType<AccDataType>()(Number<0>{}) =
c_scale_thread_buf_up[Number<cscale_offset>{}];
c_scale_thread_vec_up.template AsType<AccDataType>()(Number<1>{}) =
c_scale_thread_buf_up[Number<cscale_offset>{}];
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec_up.template AsType<AccDataType>()(Number<0>{}) =
c_scale_thread_buf_up[Number<cscale_offset>{}];
c_scale_thread_vec_up.template AsType<AccDataType>()(Number<1>{}) =
c_scale_thread_buf_up[Number<cscale_offset>{}];
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec_up;
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec_up;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0,
I0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
I0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
ik))>{}];
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec_up.template AsType<mfma_input_type>(),
c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}));
});
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec_up.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0,
I0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
I0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(make_tuple(
n0, I0, kscale0 * KRepeat / num_scale_k_block + k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(make_tuple(
n0, I0, kscale0 * KRepeat / num_scale_k_block + k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec_up.template AsType<mfma_input_type>(),
c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}));
});
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec_up.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
});
});
}

View File

@@ -569,17 +569,16 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3<
// Local prefetch A1
block_sync_lds();
static_for<0, 2, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
static_ford<Sequence<2, KRepeat, KGroup>>{}([&](auto mkk) {
constexpr auto m0 = Number<mkk[Number<0>{}]>{};
constexpr auto k0 = Number<mkk[Number<1>{}]>{};
constexpr auto kg0 = Number<mkk[Number<2>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
__builtin_amdgcn_sched_barrier(0);
@@ -729,80 +728,80 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3<
{
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(local_read_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
static_ford<Sequence<KRepeat, KGroup>>{}([&](auto kk) {
constexpr auto k0 = Number<kk[Number<0>{}]>{};
constexpr auto kg0 = Number<kk[Number<1>{}]>{};
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(local_read_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
}
else if constexpr(m0.value == (MRepeat - 1))
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(local_read_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
static_ford<Sequence<KRepeat, KGroup>>{}([&](auto kk) {
constexpr auto k0 = Number<kk[Number<0>{}]>{};
constexpr auto kg0 = Number<kk[Number<1>{}]>{};
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(local_read_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
}
else
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(mfma_reg_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
static_ford<Sequence<KRepeat, KGroup>>{}([&](auto kk) {
constexpr auto k0 = Number<kk[Number<0>{}]>{};
constexpr auto kg0 = Number<kk[Number<1>{}]>{};
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(mfma_reg_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
}
});
@@ -916,62 +915,62 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3<
{
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
static_ford<Sequence<KRepeat, KGroup>>{}([&](auto kk) {
constexpr auto k0 = Number<kk[Number<0>{}]>{};
constexpr auto kg0 = Number<kk[Number<1>{}]>{};
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
}
else if constexpr(m0.value == (MRepeat - 1))
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
static_ford<Sequence<KRepeat, KGroup>>{}([&](auto kk) {
constexpr auto k0 = Number<kk[Number<0>{}]>{};
constexpr auto kg0 = Number<kk[Number<1>{}]>{};
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
}
else
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
static_ford<Sequence<KRepeat, KGroup>>{}([&](auto kk) {
constexpr auto k0 = Number<kk[Number<0>{}]>{};
constexpr auto kg0 = Number<kk[Number<1>{}]>{};
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
}
});
@@ -1058,22 +1057,22 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3<
if constexpr(m0.value < (MRepeat - 2))
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(
Number<m0 + 2>{}, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
static_ford<Sequence<KRepeat, KGroup>>{}([&](auto kk) {
constexpr auto k0 = Number<kk[Number<0>{}]>{};
constexpr auto kg0 = Number<kk[Number<1>{}]>{};
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(
Number<m0 + 2>{}, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
}
});
@@ -1156,18 +1155,18 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3<
if constexpr(m0.value < (MRepeat - 2))
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(
Number<m0 + 2>{}, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
static_ford<Sequence<KRepeat, KGroup>>{}([&](auto kk) {
constexpr auto k0 = Number<kk[Number<0>{}]>{};
constexpr auto kg0 = Number<kk[Number<1>{}]>{};
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(
Number<m0 + 2>{}, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
}
});

View File

@@ -343,22 +343,19 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1<
constexpr auto num_scale_m_block = CScaleThreadDesc{}.GetLength(Number<1>{});
constexpr auto num_scale_n_block = CScaleThreadDesc{}.GetLength(Number<2>{});
static_for<0, num_scale_m_block, 1>{}([&](auto m0) {
static_for<0, num_scale_n_block, 1>{}([&](auto n0) {
static_for<0, num_scale_k_block, 1>{}([&](auto k0) {
constexpr index_t c_offset =
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
constexpr index_t a_offset =
AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
constexpr index_t b_offset =
BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
static_ford<Sequence<num_scale_m_block, num_scale_n_block, num_scale_k_block>>{}(
[&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto k0 = Number<mnk[Number<2>{}]>{};
constexpr index_t c_offset =
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
constexpr index_t a_offset = AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
constexpr index_t b_offset = BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
c_scale_thread_buf(Number<c_offset>{}) =
a_scale_thread_buf[Number<a_offset>{}] *
b_scale_thread_buf[Number<b_offset>{}];
});
c_scale_thread_buf(Number<c_offset>{}) =
a_scale_thread_buf[Number<a_offset>{}] * b_scale_thread_buf[Number<b_offset>{}];
});
});
__builtin_amdgcn_sched_barrier(0);
// Local prefill A1
@@ -402,18 +399,16 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1<
// Local prefetch A1
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * KPack / KGroup>{}),
a_thread_buf);
});
});
static_ford<Sequence<MRepeat, KRepeat, KGroup>>{}([&](auto mkk) {
constexpr auto m0 = Number<mkk[Number<0>{}]>{};
constexpr auto k0 = Number<mkk[Number<1>{}]>{};
constexpr auto kg0 = Number<mkk[Number<2>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * KPack / KGroup>{}),
a_thread_buf);
});
// Initialize C
@@ -441,115 +436,105 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1<
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
vector_type<AccDataType, 2> c_scale_thread_vec;
constexpr index_t cscale_offset =
CScaleThreadDesc{}.CalculateOffset(
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
static_ford<Sequence<MRepeat, NRepeat, num_scale_k_block>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto kscale0 = Number<mnk[Number<2>{}]>{};
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
vector_type<AccDataType, 2> c_scale_thread_vec;
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0,
I0,
I0,
kscale0 * KRepeat / num_scale_k_block +
k0,
I0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[mfma_reg_buf][Number<
b_thread_desc_.CalculateOffset(make_tuple(
n0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
});
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}(
[&](auto t) {
using pk_fma_type =
typename vector_type<AccDataType, 2>::type;
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) =
__builtin_elementwise_fma(
c_thread_buf_per_scale
.GetVectorTypeReference(Number<0>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec
.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf
.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0,
I0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
I0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(make_tuple(
n0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
});
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
});
});
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * KPack / KGroup>{}),
a_thread_buf);
});
});
static_ford<Sequence<MRepeat, KRepeat, KGroup>>{}([&](auto mkk) {
constexpr auto m0 = Number<mkk[Number<0>{}]>{};
constexpr auto k0 = Number<mkk[Number<1>{}]>{};
constexpr auto kg0 = Number<mkk[Number<2>{}]>{};
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * KPack / KGroup>{}),
a_thread_buf);
});
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, num_scale_n_block, 1>{}([&](auto n0) {
static_for<0, num_scale_k_block, 1>{}([&](auto k0) {
constexpr index_t c_offset =
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
constexpr index_t a_offset =
AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
constexpr index_t b_offset =
BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
static_ford<Sequence<MRepeat, num_scale_n_block, num_scale_k_block>>{}(
[&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto k0 = Number<mnk[Number<2>{}]>{};
constexpr index_t c_offset =
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
constexpr index_t a_offset =
AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
constexpr index_t b_offset =
BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
c_scale_thread_buf(Number<c_offset>{}) =
a_scale_thread_buf[Number<a_offset>{}] *
b_scale_thread_buf[Number<b_offset>{}];
});
c_scale_thread_buf(Number<c_offset>{}) =
a_scale_thread_buf[Number<a_offset>{}] *
b_scale_thread_buf[Number<b_offset>{}];
});
});
__builtin_amdgcn_sched_barrier(0);
a_scale_thread_copy.Run(a_scale_grid_desc,
@@ -597,233 +582,209 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1<
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
vector_type<AccDataType, 2> c_scale_thread_vec;
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
static_ford<Sequence<MRepeat, NRepeat, num_scale_k_block>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto kscale0 = Number<mnk[Number<2>{}]>{};
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
vector_type<AccDataType, 2> c_scale_thread_vec;
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0,
I0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
I0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
});
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0,
I0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
I0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(make_tuple(
n0, I0, kscale0 * KRepeat / num_scale_k_block + k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
});
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
});
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, num_scale_n_block, 1>{}([&](auto n0) {
static_for<0, num_scale_k_block, 1>{}([&](auto k0) {
constexpr index_t c_offset =
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
constexpr index_t a_offset =
AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
constexpr index_t b_offset =
BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
static_ford<Sequence<MRepeat, num_scale_n_block, num_scale_k_block>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto k0 = Number<mnk[Number<2>{}]>{};
constexpr index_t c_offset =
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
constexpr index_t a_offset = AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
constexpr index_t b_offset = BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
c_scale_thread_buf(Number<c_offset>{}) =
a_scale_thread_buf[Number<a_offset>{}] *
b_scale_thread_buf[Number<b_offset>{}];
});
});
c_scale_thread_buf(Number<c_offset>{}) =
a_scale_thread_buf[Number<a_offset>{}] * b_scale_thread_buf[Number<b_offset>{}];
});
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * KPack / KGroup>{}),
a_thread_buf);
});
});
static_ford<Sequence<MRepeat, KRepeat, KGroup>>{}([&](auto mkk) {
constexpr auto m0 = Number<mkk[Number<0>{}]>{};
constexpr auto k0 = Number<mkk[Number<1>{}]>{};
constexpr auto kg0 = Number<mkk[Number<2>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * KPack / KGroup>{}),
a_thread_buf);
});
// __builtin_amdgcn_sched_barrier(0);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
vector_type<AccDataType, 2> c_scale_thread_vec;
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
static_ford<Sequence<MRepeat, NRepeat, num_scale_k_block>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto kscale0 = Number<mnk[Number<2>{}]>{};
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
vector_type<AccDataType, 2> c_scale_thread_vec;
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0,
I0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
I0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
});
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0,
I0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
I0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(make_tuple(
n0, I0, kscale0 * KRepeat / num_scale_k_block + k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
});
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
});
});
}
else if constexpr(TailNum == TailNumber::Odd)
{
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
vector_type<AccDataType, 2> c_scale_thread_vec;
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
static_ford<Sequence<MRepeat, NRepeat, num_scale_k_block>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto kscale0 = Number<mnk[Number<2>{}]>{};
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
vector_type<AccDataType, 2> c_scale_thread_vec;
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0,
I0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
I0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
});
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0,
I0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
I0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(make_tuple(
n0, I0, kscale0 * KRepeat / num_scale_k_block + k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
});
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
});
});
}

View File

@@ -526,17 +526,16 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3<
// Local prefetch A1
block_sync_lds();
static_for<0, 2, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
static_ford<Sequence<2, KRepeat, KGroup>>{}([&](auto mkk) {
constexpr auto m0 = Number<mkk[Number<0>{}]>{};
constexpr auto k0 = Number<mkk[Number<1>{}]>{};
constexpr auto kg0 = Number<mkk[Number<2>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
#if 0
@@ -672,80 +671,80 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3<
{
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(local_read_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
static_ford<Sequence<KRepeat, KGroup>>{}([&](auto kk) {
constexpr auto k0 = Number<kk[Number<0>{}]>{};
constexpr auto kg0 = Number<kk[Number<1>{}]>{};
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(local_read_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
}
else if constexpr(m0.value == (MRepeat - 1))
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(local_read_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
static_ford<Sequence<KRepeat, KGroup>>{}([&](auto kk) {
constexpr auto k0 = Number<kk[Number<0>{}]>{};
constexpr auto kg0 = Number<kk[Number<1>{}]>{};
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(local_read_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
}
else
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(mfma_reg_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
static_ford<Sequence<KRepeat, KGroup>>{}([&](auto kk) {
constexpr auto k0 = Number<kk[Number<0>{}]>{};
constexpr auto kg0 = Number<kk[Number<1>{}]>{};
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(mfma_reg_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
}
});
@@ -830,62 +829,62 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3<
{
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
static_ford<Sequence<KRepeat, KGroup>>{}([&](auto kk) {
constexpr auto k0 = Number<kk[Number<0>{}]>{};
constexpr auto kg0 = Number<kk[Number<1>{}]>{};
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
}
else if constexpr(m0.value == (MRepeat - 1))
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
static_ford<Sequence<KRepeat, KGroup>>{}([&](auto kk) {
constexpr auto k0 = Number<kk[Number<0>{}]>{};
constexpr auto kg0 = Number<kk[Number<1>{}]>{};
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
}
else
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
static_ford<Sequence<KRepeat, KGroup>>{}([&](auto kk) {
constexpr auto k0 = Number<kk[Number<0>{}]>{};
constexpr auto kg0 = Number<kk[Number<1>{}]>{};
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
}
});
@@ -947,22 +946,22 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3<
if constexpr(m0.value < (MRepeat - 2))
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(
Number<m0 + 2>{}, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
static_ford<Sequence<KRepeat, KGroup>>{}([&](auto kk) {
constexpr auto k0 = Number<kk[Number<0>{}]>{};
constexpr auto kg0 = Number<kk[Number<1>{}]>{};
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(
Number<m0 + 2>{}, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
}
});
@@ -1023,18 +1022,18 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3<
if constexpr(m0.value < (MRepeat - 2))
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(
Number<m0 + 2>{}, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
static_ford<Sequence<KRepeat, KGroup>>{}([&](auto kk) {
constexpr auto k0 = Number<kk[Number<0>{}]>{};
constexpr auto kg0 = Number<kk[Number<1>{}]>{};
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(
Number<m0 + 2>{}, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
}
});

View File

@@ -324,133 +324,128 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v1<BlockGemmPipelineScheduler::In
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
(APackedSize * KPack / xdlops_gemm.K1PerXdlops);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
[&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step +
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
I0,
Number<a_k_step_chunk>{}),
a_block_buf,
a_thread_desc_,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read block data in chunks to assemble correct thread vectors
static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}(
[&](auto chunk) {
constexpr auto b_k_step_chunk =
k_step +
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
I0,
Number<b_k_step_chunk>{}),
b_block_buf,
b_thread_desc_,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
b_thread_buf);
});
});
static_ford<Sequence<MRepeat,
xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk)>>{}(
[&](auto mc) {
constexpr auto m0 = Number<mc[Number<0>{}]>{};
constexpr auto chunk = Number<mc[Number<1>{}]>{};
constexpr auto a_k_step_chunk =
k_step +
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
I0,
Number<a_k_step_chunk>{}),
a_block_buf,
a_thread_desc_,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
// read block data in chunks to assemble correct thread vectors
static_ford<Sequence<NRepeat,
xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk)>>{}(
[&](auto nc) {
constexpr auto n0 = Number<nc[Number<0>{}]>{};
constexpr auto chunk = Number<nc[Number<1>{}]>{};
constexpr auto b_k_step_chunk =
k_step +
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
I0,
Number<b_k_step_chunk>{}),
b_block_buf,
b_thread_desc_,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
b_thread_buf);
});
});
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
static_ford<Sequence<MRepeat / MXdlPack, NRepeat / NXdlPack, KRepeat / KXdlPack>>{}(
[&](auto mnk) {
constexpr auto m0 = mnk[Number<0>{}];
constexpr auto n0 = mnk[Number<1>{}];
constexpr auto k0 = mnk[Number<2>{}];
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_buf[Number<a_scale_offset + s>{}];
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_buf[Number<a_scale_offset + s>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_buf[Number<b_scale_offset + s>{}];
});
static_ford<Sequence<KXdlPack, MXdlPack, NXdlPack>>{}([&](auto kmn_xdl) {
constexpr auto ikxdl = Number<kmn_xdl[Number<0>{}]>{};
constexpr auto imxdl = Number<kmn_xdl[Number<1>{}]>{};
constexpr auto inxdl = Number<kmn_xdl[Number<2>{}]>{};
constexpr auto kxdl = ikxdl + k0 * KXdlPack;
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_buf[Number<b_scale_offset + s>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
constexpr auto kxdl = ikxdl + k0 * KXdlPack;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
using mfma_scale_input_type_a =
typename vector_type<AScaleDataType, a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b =
typename vector_type<BScaleDataType, b_scale_thread_vec_size>::type;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
});
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0));
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops /
APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops /
BPackedSize>::type;
using mfma_scale_input_type_a =
typename vector_type<AScaleDataType,
a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b =
typename vector_type<BScaleDataType,
b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, imxdl, inxdl, 0));
// MFMA accumulation
xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
ikxdl * NXdlPack + inxdl>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec
.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec
.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(
Number<c_offset>{}));
});
});
});
// MFMA accumulation
xdlops_gemm
.template Run<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
// Prefetch a_scales
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
@@ -510,132 +505,126 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v1<BlockGemmPipelineScheduler::In
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
(APackedSize * KPack / xdlops_gemm.K1PerXdlops);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
[&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step +
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
I0,
Number<a_k_step_chunk>{}),
a_block_buf,
a_thread_desc_,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read block data in chunks to assemble correct thread vectors
static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}(
[&](auto chunk) {
constexpr auto b_k_step_chunk =
k_step +
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
I0,
Number<b_k_step_chunk>{}),
b_block_buf,
b_thread_desc_,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
b_thread_buf);
});
});
static_ford<
Sequence<MRepeat, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk)>>{}(
[&](auto mc) {
constexpr auto m0 = Number<mc[Number<0>{}]>{};
constexpr auto chunk = Number<mc[Number<1>{}]>{};
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
I0,
Number<a_k_step_chunk>{}),
a_block_buf,
a_thread_desc_,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
// read block data in chunks to assemble correct thread vectors
static_ford<
Sequence<NRepeat, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk)>>{}(
[&](auto nc) {
constexpr auto n0 = Number<nc[Number<0>{}]>{};
constexpr auto chunk = Number<nc[Number<1>{}]>{};
constexpr auto b_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
I0,
Number<b_k_step_chunk>{}),
b_block_buf,
b_thread_desc_,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
b_thread_buf);
});
});
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
static_ford<Sequence<MRepeat / MXdlPack, NRepeat / NXdlPack, KRepeat / KXdlPack>>{}(
[&](auto mnk) {
constexpr auto m0 = mnk[Number<0>{}];
constexpr auto n0 = mnk[Number<1>{}];
constexpr auto k0 = mnk[Number<2>{}];
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_buf[Number<a_scale_offset + s>{}];
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_buf[Number<a_scale_offset + s>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_buf[Number<b_scale_offset + s>{}];
});
static_ford<Sequence<KXdlPack, MXdlPack, NXdlPack>>{}([&](auto kmn_xdl) {
constexpr auto ikxdl = Number<kmn_xdl[Number<0>{}]>{};
constexpr auto imxdl = Number<kmn_xdl[Number<1>{}]>{};
constexpr auto inxdl = Number<kmn_xdl[Number<2>{}]>{};
constexpr auto kxdl = ikxdl + k0 * KXdlPack;
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_buf[Number<b_scale_offset + s>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
constexpr auto kxdl = ikxdl + k0 * KXdlPack;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
using mfma_scale_input_type_a =
typename vector_type<AScaleDataType, a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b =
typename vector_type<BScaleDataType, b_scale_thread_vec_size>::type;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
});
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0));
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops /
APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops /
BPackedSize>::type;
using mfma_scale_input_type_a =
typename vector_type<AScaleDataType,
a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b =
typename vector_type<BScaleDataType,
b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, imxdl, inxdl, 0));
// MFMA accumulation
xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
ikxdl * NXdlPack + inxdl>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec
.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec
.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
// MFMA accumulation
xdlops_gemm
.template Run<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
}

View File

@@ -492,49 +492,51 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
(APackedSize * KPack / xdlops_gemm.K1PerXdlops);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
[&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
I0,
Number<a_k_step_chunk>{}),
a_block_buf,
a_thread_desc_,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read block data in chunks to assemble correct thread vectors
static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}(
[&](auto chunk) {
constexpr auto b_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
I0,
Number<b_k_step_chunk>{}),
b_block_buf,
b_thread_desc_,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
b_thread_buf);
});
});
static_ford<
Sequence<MRepeat, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk)>>{}(
[&](auto mc) {
constexpr auto m0 = Number<mc[Number<0>{}]>{};
constexpr auto chunk = Number<mc[Number<1>{}]>{};
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
I0,
Number<a_k_step_chunk>{}),
a_block_buf,
a_thread_desc_,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
// read block data in chunks to assemble correct thread vectors
static_ford<
Sequence<NRepeat, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk)>>{}(
[&](auto nc) {
constexpr auto n0 = Number<nc[Number<0>{}]>{};
constexpr auto chunk = Number<nc[Number<1>{}]>{};
constexpr auto b_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
I0,
Number<b_k_step_chunk>{}),
b_block_buf,
b_thread_desc_,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
b_thread_buf);
});
});
// Initialize C
@@ -603,91 +605,78 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
static_ford<Sequence<MRepeat / MXdlPack,
NRepeat / NXdlPack,
KRepeat / KXdlPack>>{}([&](auto mnk) {
constexpr auto m0 = mnk[Number<0>{}];
constexpr auto n0 = mnk[Number<1>{}];
constexpr auto k0 = mnk[Number<2>{}];
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
vector_type<AScaleDataType, a_scale_thread_vec_size>
a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size>
b_scale_thread_vec;
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(
scale_comp_buf)[Number<a_scale_offset + s>{}];
});
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(
scale_comp_buf)[Number<b_scale_offset + s>{}];
});
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(scale_comp_buf)[Number<a_scale_offset + s>{}];
});
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
constexpr auto kxdl = ikxdl + k0 * KXdlPack;
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(scale_comp_buf)[Number<b_scale_offset + s>{}];
});
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_ford<Sequence<KXdlPack, MXdlPack, NXdlPack>>{}([&](auto kmn_xdl) {
constexpr auto ikxdl = Number<kmn_xdl[Number<0>{}]>{};
constexpr auto imxdl = Number<kmn_xdl[Number<1>{}]>{};
constexpr auto inxdl = Number<kmn_xdl[Number<2>{}]>{};
constexpr auto kxdl = ikxdl + k0 * KXdlPack;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(
ik) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(
ik) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
});
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops /
APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops /
BPackedSize>::type;
using mfma_scale_input_type_a =
typename vector_type<AScaleDataType,
a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b =
typename vector_type<BScaleDataType,
b_scale_thread_vec_size>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, imxdl, inxdl, 0));
// MFMA accumulation
xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
ikxdl * NXdlPack + inxdl>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec
.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec
.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(
Number<c_offset>{}));
});
});
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
using mfma_scale_input_type_a =
typename vector_type<AScaleDataType, a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b =
typename vector_type<BScaleDataType, b_scale_thread_vec_size>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0));
// MFMA accumulation
xdlops_gemm
.template Run<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
@@ -805,299 +794,281 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
static_ford<Sequence<MRepeat / MXdlPack, NRepeat / NXdlPack, KRepeat / KXdlPack>>{}(
[&](auto mnk) {
constexpr auto m0 = mnk[Number<0>{}];
constexpr auto n0 = mnk[Number<1>{}];
constexpr auto k0 = mnk[Number<2>{}];
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
});
static_ford<Sequence<KXdlPack, MXdlPack, NXdlPack>>{}([&](auto kmn_xdl) {
constexpr auto ikxdl = Number<kmn_xdl[Number<0>{}]>{};
constexpr auto imxdl = Number<kmn_xdl[Number<1>{}]>{};
constexpr auto inxdl = Number<kmn_xdl[Number<2>{}]>{};
constexpr auto kxdl = ikxdl + k0 * KXdlPack;
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
constexpr auto kxdl = ikxdl + k0 * KXdlPack;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
using mfma_scale_input_type_a =
typename vector_type<AScaleDataType, a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b =
typename vector_type<BScaleDataType, b_scale_thread_vec_size>::type;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
});
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0));
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops /
APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops /
BPackedSize>::type;
using mfma_scale_input_type_a =
typename vector_type<AScaleDataType,
a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b =
typename vector_type<BScaleDataType,
b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, imxdl, inxdl, 0));
// MFMA accumulation
xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
ikxdl * NXdlPack + inxdl>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec
.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec
.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
// MFMA accumulation
xdlops_gemm
.template Run<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
(APackedSize * KPack / xdlops_gemm.K1PerXdlops);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
[&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step +
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
I0,
Number<a_k_step_chunk>{}),
a_block_buf,
a_thread_desc_,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read block data in chunks to assemble correct thread vectors
static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}(
[&](auto chunk) {
constexpr auto b_k_step_chunk =
k_step +
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
I0,
Number<b_k_step_chunk>{}),
b_block_buf,
b_thread_desc_,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
b_thread_buf);
});
});
static_ford<
Sequence<MRepeat, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk)>>{}(
[&](auto mc) {
constexpr auto m0 = Number<mc[Number<0>{}]>{};
constexpr auto chunk = Number<mc[Number<1>{}]>{};
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
I0,
Number<a_k_step_chunk>{}),
a_block_buf,
a_thread_desc_,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
// read block data in chunks to assemble correct thread vectors
static_ford<
Sequence<NRepeat, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk)>>{}(
[&](auto nc) {
constexpr auto n0 = Number<nc[Number<0>{}]>{};
constexpr auto chunk = Number<nc[Number<1>{}]>{};
constexpr auto b_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
I0,
Number<b_k_step_chunk>{}),
b_block_buf,
b_thread_desc_,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
b_thread_buf);
});
});
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
static_ford<Sequence<MRepeat / MXdlPack, NRepeat / NXdlPack, KRepeat / KXdlPack>>{}(
[&](auto mnk) {
constexpr auto m0 = mnk[Number<0>{}];
constexpr auto n0 = mnk[Number<1>{}];
constexpr auto k0 = mnk[Number<2>{}];
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(I1)[Number<a_scale_offset + s>{}];
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(I1)[Number<a_scale_offset + s>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(I1)[Number<b_scale_offset + s>{}];
});
static_ford<Sequence<KXdlPack, MXdlPack, NXdlPack>>{}([&](auto kmn_xdl) {
constexpr auto ikxdl = Number<kmn_xdl[Number<0>{}]>{};
constexpr auto imxdl = Number<kmn_xdl[Number<1>{}]>{};
constexpr auto inxdl = Number<kmn_xdl[Number<2>{}]>{};
constexpr auto kxdl = ikxdl + k0 * KXdlPack;
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(I1)[Number<b_scale_offset + s>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
constexpr auto kxdl = ikxdl + k0 * KXdlPack;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
using mfma_scale_input_type_a =
typename vector_type<AScaleDataType, a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b =
typename vector_type<BScaleDataType, b_scale_thread_vec_size>::type;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
});
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0));
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops /
APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops /
BPackedSize>::type;
using mfma_scale_input_type_a =
typename vector_type<AScaleDataType,
a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b =
typename vector_type<BScaleDataType,
b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, imxdl, inxdl, 0));
// MFMA accumulation
xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
ikxdl * NXdlPack + inxdl>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec
.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec
.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
// MFMA accumulation
xdlops_gemm
.template Run<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
else if constexpr(TailNum == TailNumber::Odd)
{
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
static_ford<Sequence<MRepeat / MXdlPack, NRepeat / NXdlPack, KRepeat / KXdlPack>>{}(
[&](auto mnk) {
constexpr auto m0 = mnk[Number<0>{}];
constexpr auto n0 = mnk[Number<1>{}];
constexpr auto k0 = mnk[Number<2>{}];
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
});
static_ford<Sequence<KXdlPack, MXdlPack, NXdlPack>>{}([&](auto kmn_xdl) {
constexpr auto ikxdl = Number<kmn_xdl[Number<0>{}]>{};
constexpr auto imxdl = Number<kmn_xdl[Number<1>{}]>{};
constexpr auto inxdl = Number<kmn_xdl[Number<2>{}]>{};
constexpr auto kxdl = ikxdl + k0 * KXdlPack;
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
constexpr auto kxdl = ikxdl + k0 * KXdlPack;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
using mfma_scale_input_type_a =
typename vector_type<AScaleDataType, a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b =
typename vector_type<BScaleDataType, b_scale_thread_vec_size>::type;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
});
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0));
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops /
APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops /
BPackedSize>::type;
using mfma_scale_input_type_a =
typename vector_type<AScaleDataType,
a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b =
typename vector_type<BScaleDataType,
b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, imxdl, inxdl, 0));
// MFMA accumulation
xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
ikxdl * NXdlPack + inxdl>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec
.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec
.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
// MFMA accumulation
xdlops_gemm
.template Run<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
}

View File

@@ -459,49 +459,51 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto k_step = k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops;
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
[&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
I0,
Number<a_k_step_chunk>{}),
a_block_bufs(I0),
a_thread_desc_,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read block data in chunks to assemble correct thread vectors
static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}(
[&](auto chunk) {
constexpr auto b_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
I0,
Number<b_k_step_chunk>{}),
b_block_bufs(I0),
b_thread_desc_,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
b_thread_buf);
});
});
static_ford<
Sequence<MRepeat, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk)>>{}(
[&](auto mc) {
constexpr auto m0 = Number<mc[Number<0>{}]>{};
constexpr auto chunk = Number<mc[Number<1>{}]>{};
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
I0,
Number<a_k_step_chunk>{}),
a_block_bufs(I0),
a_thread_desc_,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
// read block data in chunks to assemble correct thread vectors
static_ford<
Sequence<NRepeat, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk)>>{}(
[&](auto nc) {
constexpr auto n0 = Number<nc[Number<0>{}]>{};
constexpr auto chunk = Number<nc[Number<1>{}]>{};
constexpr auto b_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
I0,
Number<b_k_step_chunk>{}),
b_block_bufs(I0),
b_thread_desc_,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
b_thread_buf);
});
});
// Global prefetch 2
@@ -577,91 +579,85 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
static_ford<
Sequence<MRepeat / MXdlPack, NRepeat / NXdlPack, KRepeat / KXdlPack>>{}(
[&](auto mnk) {
constexpr auto m0 = mnk[Number<0>{}];
constexpr auto n0 = mnk[Number<1>{}];
constexpr auto k0 = mnk[Number<2>{}];
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
vector_type<AScaleDataType, a_scale_thread_vec_size>
a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size>
b_scale_thread_vec;
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(
scale_comp_buf)[Number<a_scale_offset + s>{}];
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(
scale_comp_buf)[Number<a_scale_offset + s>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(
scale_comp_buf)[Number<b_scale_offset + s>{}];
});
static_ford<Sequence<KXdlPack, MXdlPack, NXdlPack>>{}([&](auto
kmn_xdl) {
constexpr auto ikxdl = Number<kmn_xdl[Number<0>{}]>{};
constexpr auto imxdl = Number<kmn_xdl[Number<1>{}]>{};
constexpr auto inxdl = Number<kmn_xdl[Number<2>{}]>{};
constexpr auto kxdl = ikxdl + k0 * KXdlPack;
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(
scale_comp_buf)[Number<b_scale_offset + s>{}];
});
using mfma_input_type_a = typename vector_type< //
ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
constexpr auto kxdl = ikxdl + k0 * KXdlPack;
using mfma_input_type_b = typename vector_type< //
ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
using mfma_scale_input_type_a = typename vector_type< //
AScaleDataType,
a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b = typename vector_type< //
BScaleDataType,
b_scale_thread_vec_size>::type;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(
ik) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(
ik) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
});
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, imxdl, inxdl, 0));
using mfma_input_type_a = typename vector_type< //
ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b = typename vector_type< //
ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
using mfma_scale_input_type_a = typename vector_type< //
AScaleDataType,
a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b = typename vector_type< //
BScaleDataType,
b_scale_thread_vec_size>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, imxdl, inxdl, 0));
// MFMA accumulation
xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
ikxdl * NXdlPack + inxdl>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec
.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec
.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(
Number<c_offset>{}));
});
});
});
// MFMA accumulation
xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
ikxdl * NXdlPack + inxdl>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
// k indexes mapping to threads for 32x32x64:
// t0 : |0 --> 15 32 --> 47 | 64 --> 79 96 --> 111 | etc.
@@ -774,83 +770,81 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
});
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
static_ford<Sequence<MRepeat / MXdlPack, NRepeat / NXdlPack, KRepeat / KXdlPack>>{}(
[&](auto mnk) {
constexpr auto m0 = mnk[Number<0>{}];
constexpr auto n0 = mnk[Number<1>{}];
constexpr auto k0 = mnk[Number<2>{}];
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
});
static_ford<Sequence<KXdlPack, MXdlPack, NXdlPack>>{}([&](auto kmn_xdl) {
constexpr auto ikxdl = Number<kmn_xdl[Number<0>{}]>{};
constexpr auto imxdl = Number<kmn_xdl[Number<1>{}]>{};
constexpr auto inxdl = Number<kmn_xdl[Number<2>{}]>{};
constexpr auto kxdl = ikxdl + k0 * KXdlPack;
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
});
using mfma_input_type_a = typename vector_type< //
ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
constexpr auto kxdl = ikxdl + k0 * KXdlPack;
using mfma_input_type_b = typename vector_type< //
ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
using mfma_scale_input_type_a = typename vector_type< //
AScaleDataType,
a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b = typename vector_type< //
BScaleDataType,
b_scale_thread_vec_size>::type;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
});
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0));
using mfma_input_type_a = typename vector_type< //
ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b = typename vector_type< //
ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
using mfma_scale_input_type_a = typename vector_type< //
AScaleDataType,
a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b = typename vector_type< //
BScaleDataType,
b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, imxdl, inxdl, 0));
// MFMA accumulation
xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
ikxdl * NXdlPack + inxdl>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec
.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec
.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
// MFMA accumulation
xdlops_gemm
.template Run<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
__builtin_amdgcn_s_waitcnt(3952);
block_sync_lds();
@@ -858,210 +852,206 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlockGemmPipelineScheduler::In
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto k_step =
k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops;
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
[&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step +
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
I0,
Number<a_k_step_chunk>{}),
a_block_bufs(I1),
a_thread_desc_,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read block data in chunks to assemble correct thread vectors
static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}(
[&](auto chunk) {
constexpr auto b_k_step_chunk =
k_step +
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
I0,
Number<b_k_step_chunk>{}),
b_block_bufs(I1),
b_thread_desc_,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
b_thread_buf);
});
});
static_ford<
Sequence<MRepeat, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk)>>{}(
[&](auto mc) {
constexpr auto m0 = Number<mc[Number<0>{}]>{};
constexpr auto chunk = Number<mc[Number<1>{}]>{};
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
I0,
Number<a_k_step_chunk>{}),
a_block_bufs(I1),
a_thread_desc_,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
// read block data in chunks to assemble correct thread vectors
static_ford<
Sequence<NRepeat, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk)>>{}(
[&](auto nc) {
constexpr auto n0 = Number<nc[Number<0>{}]>{};
constexpr auto chunk = Number<nc[Number<1>{}]>{};
constexpr auto b_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
I0,
Number<b_k_step_chunk>{}),
b_block_bufs(I1),
b_thread_desc_,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
b_thread_buf);
});
});
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
static_ford<Sequence<MRepeat / MXdlPack, NRepeat / NXdlPack, KRepeat / KXdlPack>>{}(
[&](auto mnk) {
constexpr auto m0 = mnk[Number<0>{}];
constexpr auto n0 = mnk[Number<1>{}];
constexpr auto k0 = mnk[Number<2>{}];
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(I1)[Number<a_scale_offset + s>{}];
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(I1)[Number<a_scale_offset + s>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(I1)[Number<b_scale_offset + s>{}];
});
static_ford<Sequence<KXdlPack, MXdlPack, NXdlPack>>{}([&](auto kmn_xdl) {
constexpr auto ikxdl = Number<kmn_xdl[Number<0>{}]>{};
constexpr auto imxdl = Number<kmn_xdl[Number<1>{}]>{};
constexpr auto inxdl = Number<kmn_xdl[Number<2>{}]>{};
constexpr auto kxdl = ikxdl + k0 * KXdlPack;
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(I1)[Number<b_scale_offset + s>{}];
});
using mfma_input_type_a = typename vector_type< //
ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
constexpr auto kxdl = ikxdl + k0 * KXdlPack;
using mfma_input_type_b = typename vector_type< //
ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
using mfma_scale_input_type_a = typename vector_type< //
AScaleDataType,
a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b = typename vector_type< //
BScaleDataType,
b_scale_thread_vec_size>::type;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
});
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0));
using mfma_input_type_a = typename vector_type< //
ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b = typename vector_type< //
ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
using mfma_scale_input_type_a = typename vector_type< //
AScaleDataType,
a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b = typename vector_type< //
BScaleDataType,
b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, imxdl, inxdl, 0));
// MFMA accumulation
xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
ikxdl * NXdlPack + inxdl>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec
.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec
.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
// MFMA accumulation
xdlops_gemm
.template Run<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
else if constexpr(TailNum == TailNumber::Odd)
{
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
static_ford<Sequence<MRepeat / MXdlPack, NRepeat / NXdlPack, KRepeat / KXdlPack>>{}(
[&](auto mnk) {
constexpr auto m0 = mnk[Number<0>{}];
constexpr auto n0 = mnk[Number<1>{}];
constexpr auto k0 = mnk[Number<2>{}];
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
});
static_ford<Sequence<KXdlPack, MXdlPack, NXdlPack>>{}([&](auto kmn_xdl) {
constexpr auto ikxdl = Number<kmn_xdl[Number<0>{}]>{};
constexpr auto imxdl = Number<kmn_xdl[Number<1>{}]>{};
constexpr auto inxdl = Number<kmn_xdl[Number<2>{}]>{};
constexpr auto kxdl = ikxdl + k0 * KXdlPack;
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
});
using mfma_input_type_a = typename vector_type< //
ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
constexpr auto kxdl = ikxdl + k0 * KXdlPack;
using mfma_input_type_b = typename vector_type< //
ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
using mfma_scale_input_type_a = typename vector_type< //
AScaleDataType,
a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b = typename vector_type< //
BScaleDataType,
b_scale_thread_vec_size>::type;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
});
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0));
using mfma_input_type_a = typename vector_type< //
ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b = typename vector_type< //
ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
using mfma_scale_input_type_a = typename vector_type< //
AScaleDataType,
a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b = typename vector_type< //
BScaleDataType,
b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, imxdl, inxdl, 0));
// MFMA accumulation
xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
ikxdl * NXdlPack + inxdl>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec
.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec
.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
// MFMA accumulation
xdlops_gemm
.template Run<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
}

View File

@@ -220,69 +220,9 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_buf);
});
});
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
i += 1;
} while(i < (num_loop - 1));
}
// tail
if constexpr(TailNum == TailNumber::Full)
{
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_ford<Sequence<KRepeat, MRepeat>>{}([&](auto km) {
constexpr auto k = Number<km[Number<0>{}]>{};
constexpr auto m0 = Number<km[Number<1>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf,
@@ -298,34 +238,85 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
b_thread_buf);
});
});
static_ford<Sequence<KRepeat, MRepeat, NRepeat>>{}([&](auto kmn) {
constexpr auto k0 = Number<kmn[Number<0>{}]>{};
constexpr auto m0 = Number<kmn[Number<1>{}]>{};
constexpr auto n0 = Number<kmn[Number<2>{}]>{};
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
i += 1;
} while(i < (num_loop - 1));
}
// tail
if constexpr(TailNum == TailNumber::Full)
{
block_sync_lds();
static_ford<Sequence<KRepeat, MRepeat>>{}([&](auto km) {
constexpr auto k = Number<km[Number<0>{}]>{};
constexpr auto m0 = Number<km[Number<1>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_buf);
});
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_ford<Sequence<KRepeat, MRepeat, NRepeat>>{}([&](auto kmn) {
constexpr auto k0 = Number<kmn[Number<0>{}]>{};
constexpr auto m0 = Number<kmn[Number<1>{}]>{};
constexpr auto n0 = Number<kmn[Number<2>{}]>{};
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
}
}
@@ -553,51 +544,51 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
__builtin_amdgcn_sched_barrier(0);
}
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_ford<Sequence<MRepeat, NRepeat>>{}([&](auto mn) {
constexpr auto m0 = Number<mn[Number<0>{}]>{};
constexpr auto n0 = Number<mn[Number<1>{}]>{};
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
// The block_sync_lds() here performs double duty:
// A) safeguard against data hazard because barrier from
// blockwise_gemm is moved here B) reduce VMEM FIFO congestion by
// applying small delays to different wavefronts It is performed
// near the end of MAC cluster to minimize lgkmcnt penalty
if constexpr(k0.value == KRepeat - 1 &&
k_.value == KPerInnerLoop - KPack &&
m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
{
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
// The block_sync_lds() here performs double duty:
// A) safeguard against data hazard because barrier from
// blockwise_gemm is moved here B) reduce VMEM FIFO congestion by
// applying small delays to different wavefronts It is performed
// near the end of MAC cluster to minimize lgkmcnt penalty
if constexpr(k0.value == KRepeat - 1 &&
k_.value == KPerInnerLoop - KPack &&
m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
{
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
});
});
__builtin_amdgcn_sched_barrier(0);
@@ -642,46 +633,43 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
__builtin_amdgcn_sched_barrier(0);
}
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_ford<Sequence<MRepeat, NRepeat>>{}([&](auto mn) {
constexpr auto m0 = Number<mn[Number<0>{}]>{};
constexpr auto n0 = Number<mn[Number<1>{}]>{};
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
if constexpr(k0.value == KRepeat - 1 &&
k_.value == KPerInnerLoop - KPack &&
m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
{
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
if constexpr(k0.value == KRepeat - 1 && k_.value == KPerInnerLoop - KPack &&
m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
{
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
});
});
__builtin_amdgcn_sched_barrier(0);
@@ -942,73 +930,9 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
index_t i = 0;
do
{
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_buf);
});
});
});
block_sync_lds();
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
block_sync_lds_direct_load();
i += 1;
} while(i < (num_loop - 1));
}
// tail
if constexpr(TailNum == TailNumber::Full)
{
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_ford<Sequence<KRepeat, MRepeat>>{}([&](auto km) {
constexpr auto k = Number<km[Number<0>{}]>{};
constexpr auto m0 = Number<km[Number<1>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf,
@@ -1024,34 +948,89 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
b_thread_buf);
});
});
block_sync_lds();
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_ford<Sequence<KRepeat, MRepeat, NRepeat>>{}([&](auto kmn) {
constexpr auto k0 = Number<kmn[Number<0>{}]>{};
constexpr auto m0 = Number<kmn[Number<1>{}]>{};
constexpr auto n0 = Number<kmn[Number<2>{}]>{};
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
block_sync_lds_direct_load();
i += 1;
} while(i < (num_loop - 1));
}
// tail
if constexpr(TailNum == TailNumber::Full)
{
static_ford<Sequence<KRepeat, MRepeat>>{}([&](auto km) {
constexpr auto k = Number<km[Number<0>{}]>{};
constexpr auto m0 = Number<km[Number<1>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_buf);
});
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_ford<Sequence<KRepeat, MRepeat, NRepeat>>{}([&](auto kmn) {
constexpr auto k0 = Number<kmn[Number<0>{}]>{};
constexpr auto m0 = Number<kmn[Number<1>{}]>{};
constexpr auto n0 = Number<kmn[Number<2>{}]>{};
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
}
}

View File

@@ -406,22 +406,19 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
constexpr auto num_scale_m_block = CScaleThreadDesc{}.GetLength(Number<1>{});
constexpr auto num_scale_n_block = CScaleThreadDesc{}.GetLength(Number<2>{});
static_for<0, num_scale_m_block, 1>{}([&](auto m0) {
static_for<0, num_scale_n_block, 1>{}([&](auto n0) {
static_for<0, num_scale_k_block, 1>{}([&](auto k0) {
constexpr index_t c_offset =
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
constexpr index_t a_offset =
AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
constexpr index_t b_offset =
BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
static_ford<Sequence<num_scale_m_block, num_scale_n_block, num_scale_k_block>>{}(
[&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto k0 = Number<mnk[Number<2>{}]>{};
constexpr index_t c_offset =
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
constexpr index_t a_offset = AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
constexpr index_t b_offset = BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
c_scale_thread_buf(Number<c_offset>{}) =
a_scale_thread_buf[Number<a_offset>{}] *
b_scale_thread_buf[Number<b_offset>{}];
});
c_scale_thread_buf(Number<c_offset>{}) =
a_scale_thread_buf[Number<a_offset>{}] * b_scale_thread_buf[Number<b_offset>{}];
});
});
// Local prefill 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
@@ -512,74 +509,64 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_ford<Sequence<MRepeat, NRepeat, num_scale_k_block>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto kscale0 = Number<mnk[Number<2>{}]>{};
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
});
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
constexpr index_t cscale_offset =
CScaleThreadDesc{}.CalculateOffset(
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()[Number<t>{}] *
type_convert<AccDataType>(
c_scale_thread_buf[Number<cscale_offset>{}]);
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
m0, I0, kscale0 * KRepeat / num_scale_k_block + k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
n0, I0, kscale0 * KRepeat / num_scale_k_block + k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
});
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()[Number<t>{}] *
type_convert<AccDataType>(c_scale_thread_buf[Number<cscale_offset>{}]);
});
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, num_scale_n_block, 1>{}([&](auto n0) {
static_for<0, num_scale_k_block, 1>{}([&](auto k0) {
constexpr index_t c_offset =
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
constexpr index_t a_offset =
AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
constexpr index_t b_offset =
BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
static_ford<Sequence<MRepeat, num_scale_n_block, num_scale_k_block>>{}(
[&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto k0 = Number<mnk[Number<2>{}]>{};
constexpr index_t c_offset =
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
constexpr index_t a_offset =
AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
constexpr index_t b_offset =
BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
c_scale_thread_buf(Number<c_offset>{}) =
a_scale_thread_buf[Number<a_offset>{}] *
b_scale_thread_buf[Number<b_offset>{}];
});
c_scale_thread_buf(Number<c_offset>{}) =
a_scale_thread_buf[Number<a_offset>{}] *
b_scale_thread_buf[Number<b_offset>{}];
});
});
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
@@ -642,72 +629,59 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_ford<Sequence<MRepeat, NRepeat, num_scale_k_block>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto kscale0 = Number<mnk[Number<2>{}]>{};
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
});
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()[Number<t>{}] *
type_convert<AccDataType>(
c_scale_thread_buf[Number<cscale_offset>{}]);
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
m0, I0, kscale0 * KRepeat / num_scale_k_block + k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
n0, I0, kscale0 * KRepeat / num_scale_k_block + k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
});
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()[Number<t>{}] *
type_convert<AccDataType>(c_scale_thread_buf[Number<cscale_offset>{}]);
});
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, num_scale_n_block, 1>{}([&](auto n0) {
static_for<0, num_scale_k_block, 1>{}([&](auto k0) {
constexpr index_t c_offset =
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
constexpr index_t a_offset =
AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
constexpr index_t b_offset =
BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
static_ford<Sequence<MRepeat, num_scale_n_block, num_scale_k_block>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto k0 = Number<mnk[Number<2>{}]>{};
constexpr index_t c_offset =
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
constexpr index_t a_offset = AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
constexpr index_t b_offset = BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
c_scale_thread_buf(Number<c_offset>{}) =
a_scale_thread_buf[Number<a_offset>{}] *
b_scale_thread_buf[Number<b_offset>{}];
});
});
c_scale_thread_buf(Number<c_offset>{}) =
a_scale_thread_buf[Number<a_offset>{}] * b_scale_thread_buf[Number<b_offset>{}];
});
block_sync_lds();
@@ -733,108 +707,90 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_ford<Sequence<MRepeat, NRepeat, num_scale_k_block>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto kscale0 = Number<mnk[Number<2>{}]>{};
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
});
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()[Number<t>{}] *
type_convert<AccDataType>(
c_scale_thread_buf[Number<cscale_offset>{}]);
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
m0, I0, kscale0 * KRepeat / num_scale_k_block + k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
n0, I0, kscale0 * KRepeat / num_scale_k_block + k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
});
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()[Number<t>{}] *
type_convert<AccDataType>(c_scale_thread_buf[Number<cscale_offset>{}]);
});
});
__builtin_amdgcn_sched_barrier(0);
}
else if constexpr(TailNum == TailNumber::Odd)
{
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_ford<Sequence<MRepeat, NRepeat, num_scale_k_block>>{}([&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto kscale0 = Number<mnk[Number<2>{}]>{};
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
});
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()[Number<t>{}] *
type_convert<AccDataType>(
c_scale_thread_buf[Number<cscale_offset>{}]);
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
m0, I0, kscale0 * KRepeat / num_scale_k_block + k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
n0, I0, kscale0 * KRepeat / num_scale_k_block + k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
});
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()[Number<t>{}] *
type_convert<AccDataType>(c_scale_thread_buf[Number<cscale_offset>{}]);
});
});
__builtin_amdgcn_sched_barrier(0);

View File

@@ -277,38 +277,37 @@ struct BlockwiseGemmXdlops_pipeline_v1_b_scale<BlockGemmPipelineScheduler::Intra
});
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
c_thread_buf_per_scale.Clear();
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_ford<Sequence<MRepeat, NRepeat>>{}([&](auto mn) {
constexpr auto m0 = Number<mn[Number<0>{}]>{};
constexpr auto n0 = Number<mn[Number<1>{}]>{};
c_thread_buf_per_scale.Clear();
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
});
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale[Number<t>{}] *
type_convert<AccDataType>(b_scale_thread_buf[n0]);
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
});
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale[Number<t>{}] *
type_convert<AccDataType>(b_scale_thread_buf[n0]);
});
});
@@ -358,37 +357,34 @@ struct BlockwiseGemmXdlops_pipeline_v1_b_scale<BlockGemmPipelineScheduler::Intra
});
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
c_thread_buf_per_scale.Clear();
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_ford<Sequence<MRepeat, NRepeat>>{}([&](auto mn) {
constexpr auto m0 = Number<mn[Number<0>{}]>{};
constexpr auto n0 = Number<mn[Number<1>{}]>{};
c_thread_buf_per_scale.Clear();
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
});
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale[Number<t>{}] *
type_convert<AccDataType>(b_scale_thread_buf[n0]);
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
});
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale[Number<t>{}] *
type_convert<AccDataType>(b_scale_thread_buf[n0]);
});
});
}

View File

@@ -331,56 +331,60 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
// t32: |32 --> 47 96 --> 111| 160 --> 175 224 --> 239| etc.
// t48: |48 --> 63 112 --> 127| 176 --> 191 240 --> 255| etc.
// k = 0 k = 1
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto k_step =
k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops;
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, xdlops_gemm.K1PerXdlops / APackedSize / KThreadChunk, 1>{}(
[&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step +
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
I0,
Number<a_k_step_chunk>{}),
a_block_buf,
a_thread_desc_,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
// LDS reads for A
static_ford<Sequence<KRepeat,
MRepeat,
xdlops_gemm.K1PerXdlops / APackedSize / KThreadChunk>>{}(
[&](auto km_chunk) {
constexpr auto k = Number<km_chunk[Number<0>{}]>{};
constexpr auto m0 = Number<km_chunk[Number<1>{}]>{};
constexpr auto chunk = Number<km_chunk[Number<2>{}]>{};
constexpr auto k_step =
k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops;
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
I0,
Number<a_k_step_chunk>{}),
a_block_buf,
a_thread_desc_,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read block data in chunks to assemble correct thread vectors
static_for<0, xdlops_gemm.K1PerXdlops / BPackedSize / KThreadChunk, 1>{}(
[&](auto chunk) {
constexpr auto b_k_step_chunk =
k_step +
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
I0,
Number<b_k_step_chunk>{}),
b_block_buf,
b_thread_desc_,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
b_thread_buf);
});
// LDS reads for B
static_ford<Sequence<KRepeat,
NRepeat,
xdlops_gemm.K1PerXdlops / BPackedSize / KThreadChunk>>{}(
[&](auto kn_chunk) {
constexpr auto k = Number<kn_chunk[Number<0>{}]>{};
constexpr auto n0 = Number<kn_chunk[Number<1>{}]>{};
constexpr auto chunk = Number<kn_chunk[Number<2>{}]>{};
constexpr auto k_step =
k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops;
constexpr auto b_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
I0,
Number<b_k_step_chunk>{}),
b_block_buf,
b_thread_desc_,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
b_thread_buf);
});
});
// load for next k loop
block_sync_lds();
@@ -389,82 +393,78 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
static_ford<Sequence<MRepeat / MXdlPack, NRepeat / NXdlPack, KRepeat / KXdlPack>>{}(
[&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto k0 = Number<mnk[Number<2>{}]>{};
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
static_assert(0 < ScalesPerXdlopsRunPerThreadA &&
0 < ScalesPerXdlopsRunPerThreadB,
"Must have at least one scale per Xdlops per Thread.");
static_assert(0 < ScalesPerXdlopsRunPerThreadA &&
0 < ScalesPerXdlopsRunPerThreadB,
"Must have at least one scale per Xdlops per Thread.");
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_buf[Number<a_scale_offset + s>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_buf[Number<b_scale_offset + s>{}];
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_buf[Number<a_scale_offset + s>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_buf[Number<b_scale_offset + s>{}];
});
static_ford<Sequence<KXdlPack, MXdlPack, NXdlPack>>{}([&](auto kmn_xdl) {
constexpr auto ikxdl = Number<kmn_xdl[Number<0>{}]>{};
constexpr auto imxdl = Number<kmn_xdl[Number<1>{}]>{};
constexpr auto inxdl = Number<kmn_xdl[Number<2>{}]>{};
constexpr auto kxdl = ikxdl + k0 * KXdlPack;
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
});
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
constexpr auto kxdl = ikxdl + k0 * KXdlPack;
using mfma_input_type_a = typename vector_type< //
ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b = typename vector_type< //
ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
using mfma_scale_input_type_a = typename vector_type< //
AScaleDataType,
a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b = typename vector_type< //
BScaleDataType,
b_scale_thread_vec_size>::type;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
});
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0));
using mfma_input_type_a = typename vector_type< //
ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b = typename vector_type< //
ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
using mfma_scale_input_type_a = typename vector_type< //
AScaleDataType,
a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b = typename vector_type< //
BScaleDataType,
b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, imxdl, inxdl, 0));
// MFMA accumulation
xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
ikxdl * NXdlPack + inxdl>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec
.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec
.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(
Number<c_offset>{}));
});
});
});
// MFMA accumulation
xdlops_gemm
.template Run<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
// Prefetch a_scales
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
@@ -519,131 +519,130 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
// tail
if constexpr(TailNum == TailNumber::Full)
{
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto k_step =
k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops;
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, xdlops_gemm.K1PerXdlops / APackedSize / KThreadChunk, 1>{}(
[&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step +
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
I0,
Number<a_k_step_chunk>{}),
a_block_buf,
a_thread_desc_,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
// LDS reads for A
static_ford<
Sequence<KRepeat, MRepeat, xdlops_gemm.K1PerXdlops / APackedSize / KThreadChunk>>{}(
[&](auto km_chunk) {
constexpr auto k = Number<km_chunk[Number<0>{}]>{};
constexpr auto m0 = Number<km_chunk[Number<1>{}]>{};
constexpr auto chunk = Number<km_chunk[Number<2>{}]>{};
constexpr auto k_step =
k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops;
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
I0,
Number<a_k_step_chunk>{}),
a_block_buf,
a_thread_desc_,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read block data in chunks to assemble correct thread vectors
static_for<0, xdlops_gemm.K1PerXdlops / BPackedSize / KThreadChunk, 1>{}(
[&](auto chunk) {
constexpr auto b_k_step_chunk =
k_step +
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
I0,
Number<b_k_step_chunk>{}),
b_block_buf,
b_thread_desc_,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
b_thread_buf);
});
// LDS reads for B
static_ford<
Sequence<KRepeat, NRepeat, xdlops_gemm.K1PerXdlops / BPackedSize / KThreadChunk>>{}(
[&](auto kn_chunk) {
constexpr auto k = Number<kn_chunk[Number<0>{}]>{};
constexpr auto n0 = Number<kn_chunk[Number<1>{}]>{};
constexpr auto chunk = Number<kn_chunk[Number<2>{}]>{};
constexpr auto k_step =
k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops;
constexpr auto b_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
I0,
Number<b_k_step_chunk>{}),
b_block_buf,
b_thread_desc_,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
b_thread_buf);
});
});
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
static_ford<Sequence<MRepeat / MXdlPack, NRepeat / NXdlPack, KRepeat / KXdlPack>>{}(
[&](auto mnk) {
constexpr auto m0 = Number<mnk[Number<0>{}]>{};
constexpr auto n0 = Number<mnk[Number<1>{}]>{};
constexpr auto k0 = Number<mnk[Number<2>{}]>{};
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
static_assert(0 < ScalesPerXdlopsRunPerThreadA &&
0 < ScalesPerXdlopsRunPerThreadB,
"Must have at least one scale per Xdlops per Thread.");
static_assert(0 < ScalesPerXdlopsRunPerThreadA &&
0 < ScalesPerXdlopsRunPerThreadB,
"Must have at least one scale per Xdlops per Thread.");
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_buf[Number<a_scale_offset + s>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_buf[Number<b_scale_offset + s>{}];
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_buf[Number<a_scale_offset + s>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_buf[Number<b_scale_offset + s>{}];
});
static_ford<Sequence<KXdlPack, MXdlPack, NXdlPack>>{}([&](auto kmn_xdl) {
constexpr auto ikxdl = Number<kmn_xdl[Number<0>{}]>{};
constexpr auto imxdl = Number<kmn_xdl[Number<1>{}]>{};
constexpr auto inxdl = Number<kmn_xdl[Number<2>{}]>{};
constexpr auto kxdl = ikxdl + k0 * KXdlPack;
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
});
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
constexpr auto kxdl = ikxdl + k0 * KXdlPack;
using mfma_input_type_a = typename vector_type< //
ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b = typename vector_type< //
ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
using mfma_scale_input_type_a = typename vector_type< //
AScaleDataType,
a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b = typename vector_type< //
BScaleDataType,
b_scale_thread_vec_size>::type;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
});
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0));
using mfma_input_type_a = typename vector_type< //
ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b = typename vector_type< //
ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
using mfma_scale_input_type_a = typename vector_type< //
AScaleDataType,
a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b = typename vector_type< //
BScaleDataType,
b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, imxdl, inxdl, 0));
// MFMA accumulation
xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
ikxdl * NXdlPack + inxdl>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec
.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec
.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
// MFMA accumulation
xdlops_gemm
.template Run<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
}

View File

@@ -283,34 +283,31 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
});
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_ford<Sequence<KRepeat, MRepeat, NRepeat>>{}([&](auto kmn) {
constexpr auto k0 = Number<kmn[Number<0>{}]>{};
constexpr auto m0 = Number<kmn[Number<1>{}]>{};
constexpr auto n0 = Number<kmn[Number<2>{}]>{};
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
block_sync_lds();
@@ -354,34 +351,29 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
});
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_ford<Sequence<KRepeat, MRepeat, NRepeat>>{}([&](auto kmn) {
constexpr auto k0 = Number<kmn[Number<0>{}]>{};
constexpr auto m0 = Number<kmn[Number<1>{}]>{};
constexpr auto n0 = Number<kmn[Number<2>{}]>{};
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
block_sync_lds();
@@ -409,32 +401,28 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
});
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_ford<Sequence<KRepeat, MRepeat, NRepeat>>{}([&](auto kmn) {
constexpr auto k0 = Number<kmn[Number<0>{}]>{};
constexpr auto m0 = Number<kmn[Number<1>{}]>{};
constexpr auto n0 = Number<kmn[Number<2>{}]>{};
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
};
@@ -460,32 +448,28 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
});
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_ford<Sequence<KRepeat, MRepeat, NRepeat>>{}([&](auto kmn) {
constexpr auto k0 = Number<kmn[Number<0>{}]>{};
constexpr auto m0 = Number<kmn[Number<1>{}]>{};
constexpr auto n0 = Number<kmn[Number<2>{}]>{};
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
}
else if constexpr(TailNum == TailNumber::Two)
@@ -788,52 +772,52 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
__builtin_amdgcn_sched_barrier(0);
}
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_ford<Sequence<MRepeat, NRepeat>>{}([&](auto mn) {
constexpr auto m0 = Number<mn[Number<0>{}]>{};
constexpr auto n0 = Number<mn[Number<1>{}]>{};
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
// The block_sync_lds() here performs double duty:
// A) safeguard against data hazard because barrier from
// blockwise_gemm is moved here B) reduce VMEM FIFO congestion
// by applying small delays to different wavefronts It is
// performed near the end of MAC cluster to minimize lgkmcnt
// penalty
if constexpr(k0.value == KRepeat - 1 &&
k_.value == KPerInnerLoop - KPack &&
m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
{
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
// The block_sync_lds() here performs double duty:
// A) safeguard against data hazard because barrier from
// blockwise_gemm is moved here B) reduce VMEM FIFO congestion
// by applying small delays to different wavefronts It is
// performed near the end of MAC cluster to minimize lgkmcnt
// penalty
if constexpr(k0.value == KRepeat - 1 &&
k_.value == KPerInnerLoop - KPack &&
m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
{
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
});
});
__builtin_amdgcn_sched_barrier(0);
@@ -887,46 +871,46 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
__builtin_amdgcn_sched_barrier(0);
}
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_ford<Sequence<MRepeat, NRepeat>>{}([&](auto mn) {
constexpr auto m0 = Number<mn[Number<0>{}]>{};
constexpr auto n0 = Number<mn[Number<1>{}]>{};
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
if constexpr(k0.value == KRepeat - 1 &&
k_.value == KPerInnerLoop - KPack &&
m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
{
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
if constexpr(k0.value == KRepeat - 1 &&
k_.value == KPerInnerLoop - KPack &&
m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
{
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
});
});
__builtin_amdgcn_sched_barrier(0);
@@ -963,46 +947,43 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
__builtin_amdgcn_sched_barrier(0);
}
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_ford<Sequence<MRepeat, NRepeat>>{}([&](auto mn) {
constexpr auto m0 = Number<mn[Number<0>{}]>{};
constexpr auto n0 = Number<mn[Number<1>{}]>{};
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
if constexpr(k0.value == KRepeat - 1 &&
k_.value == KPerInnerLoop - KPack &&
m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
{
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
if constexpr(k0.value == KRepeat - 1 && k_.value == KPerInnerLoop - KPack &&
m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
{
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
});
});
__builtin_amdgcn_sched_barrier(0);
@@ -1039,46 +1020,43 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
__builtin_amdgcn_sched_barrier(0);
}
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_ford<Sequence<MRepeat, NRepeat>>{}([&](auto mn) {
constexpr auto m0 = Number<mn[Number<0>{}]>{};
constexpr auto n0 = Number<mn[Number<1>{}]>{};
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
if constexpr(k0.value == KRepeat - 1 &&
k_.value == KPerInnerLoop - KPack &&
m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
{
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
if constexpr(k0.value == KRepeat - 1 && k_.value == KPerInnerLoop - KPack &&
m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
{
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
});
});
__builtin_amdgcn_sched_barrier(0);

View File

@@ -349,39 +349,39 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
});
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
c_thread_buf_per_scale.Clear();
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_ford<Sequence<MRepeat, NRepeat>>{}([&](auto mn) {
constexpr auto m0 = Number<mn[Number<0>{}]>{};
constexpr auto n0 = Number<mn[Number<1>{}]>{};
c_thread_buf_per_scale.Clear();
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
});
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale[Number<t>{}] *
type_convert<AccDataType>(a_scale_thread_buf[m0]) *
type_convert<AccDataType>(b_scale_thread_buf[I0]);
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
});
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale[Number<t>{}] *
type_convert<AccDataType>(a_scale_thread_buf[m0]) *
type_convert<AccDataType>(b_scale_thread_buf[I0]);
});
});
@@ -436,58 +436,57 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
auto LoopTailFunc = [&](auto tail_num) {
static_for<1, tail_num, 1>{}([&](auto iprefetch) {
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_buf);
});
static_ford<Sequence<KRepeat, MRepeat>>{}([&](auto km) {
constexpr auto k = Number<km[Number<0>{}]>{};
constexpr auto m0 = Number<km[Number<1>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_buf);
});
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
c_thread_buf_per_scale.Clear();
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_ford<Sequence<MRepeat, NRepeat>>{}([&](auto mn) {
constexpr auto m0 = Number<mn[Number<0>{}]>{};
constexpr auto n0 = Number<mn[Number<1>{}]>{};
c_thread_buf_per_scale.Clear();
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
});
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale[Number<t>{}] *
type_convert<AccDataType>(a_scale_thread_buf[m0]) *
type_convert<AccDataType>(b_scale_thread_buf[I0]);
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
});
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale[Number<t>{}] *
type_convert<AccDataType>(a_scale_thread_buf[m0]) *
type_convert<AccDataType>(b_scale_thread_buf[I0]);
});
});
@@ -526,57 +525,54 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
});
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_buf);
});
static_ford<Sequence<KRepeat, MRepeat>>{}([&](auto km) {
constexpr auto k = Number<km[Number<0>{}]>{};
constexpr auto m0 = Number<km[Number<1>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_buf);
});
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
c_thread_buf_per_scale.Clear();
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_ford<Sequence<MRepeat, NRepeat>>{}([&](auto mn) {
constexpr auto m0 = Number<mn[Number<0>{}]>{};
constexpr auto n0 = Number<mn[Number<1>{}]>{};
c_thread_buf_per_scale.Clear();
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
});
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale[Number<t>{}] *
type_convert<AccDataType>(a_scale_thread_buf[m0]) *
type_convert<AccDataType>(b_scale_thread_buf[I0]);
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
});
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale[Number<t>{}] *
type_convert<AccDataType>(a_scale_thread_buf[m0]) *
type_convert<AccDataType>(b_scale_thread_buf[I0]);
});
});
};
@@ -584,57 +580,54 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
if constexpr(TailNum == TailNumber::One)
{
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_buf);
});
static_ford<Sequence<KRepeat, MRepeat>>{}([&](auto km) {
constexpr auto k = Number<km[Number<0>{}]>{};
constexpr auto m0 = Number<km[Number<1>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_buf);
});
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
c_thread_buf_per_scale.Clear();
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_ford<Sequence<MRepeat, NRepeat>>{}([&](auto mn) {
constexpr auto m0 = Number<mn[Number<0>{}]>{};
constexpr auto n0 = Number<mn[Number<1>{}]>{};
c_thread_buf_per_scale.Clear();
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
});
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale[Number<t>{}] *
type_convert<AccDataType>(a_scale_thread_buf[m0]) *
type_convert<AccDataType>(b_scale_thread_buf[I0]);
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
});
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale[Number<t>{}] *
type_convert<AccDataType>(a_scale_thread_buf[m0]) *
type_convert<AccDataType>(b_scale_thread_buf[I0]);
});
});
}

View File

@@ -264,54 +264,50 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Intra
static_for<0, PrefetchStages, 1>{}([&](auto iprefetch) {
// -------------------------------------------------------------------------------------------
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(
b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_buf);
});
static_ford<Sequence<KRepeat, MRepeat>>{}([&](auto km) {
constexpr auto k = Number<km[Number<0>{}]>{};
constexpr auto m0 = Number<km[Number<1>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_buf);
});
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_ford<Sequence<KRepeat, MRepeat, NRepeat>>{}([&](auto kmn) {
constexpr auto k0 = Number<kmn[Number<0>{}]>{};
constexpr auto m0 = Number<kmn[Number<1>{}]>{};
constexpr auto n0 = Number<kmn[Number<2>{}]>{};
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
block_sync_lds();
@@ -336,53 +332,48 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Intra
auto LoopTailFunc = [&](auto tail_num) {
static_for<1, tail_num, 1>{}([&](auto iprefetch) {
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_buf);
});
static_ford<Sequence<KRepeat, MRepeat>>{}([&](auto km) {
constexpr auto k = Number<km[Number<0>{}]>{};
constexpr auto m0 = Number<km[Number<1>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_buf);
});
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_ford<Sequence<KRepeat, MRepeat, NRepeat>>{}([&](auto kmn) {
constexpr auto k0 = Number<kmn[Number<0>{}]>{};
constexpr auto m0 = Number<kmn[Number<1>{}]>{};
constexpr auto n0 = Number<kmn[Number<2>{}]>{};
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
block_sync_lds();
@@ -391,102 +382,94 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Intra
});
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_buf);
});
static_ford<Sequence<KRepeat, MRepeat>>{}([&](auto km) {
constexpr auto k = Number<km[Number<0>{}]>{};
constexpr auto m0 = Number<km[Number<1>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_buf);
});
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_ford<Sequence<KRepeat, MRepeat, NRepeat>>{}([&](auto kmn) {
constexpr auto k0 = Number<kmn[Number<0>{}]>{};
constexpr auto m0 = Number<kmn[Number<1>{}]>{};
constexpr auto n0 = Number<kmn[Number<2>{}]>{};
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
};
if constexpr(TailNum == TailNumber::One)
{
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_buf);
});
static_ford<Sequence<KRepeat, MRepeat>>{}([&](auto km) {
constexpr auto k = Number<km[Number<0>{}]>{};
constexpr auto m0 = Number<km[Number<1>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_buf);
});
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_ford<Sequence<KRepeat, MRepeat, NRepeat>>{}([&](auto kmn) {
constexpr auto k0 = Number<kmn[Number<0>{}]>{};
constexpr auto m0 = Number<kmn[Number<1>{}]>{};
constexpr auto n0 = Number<kmn[Number<2>{}]>{};
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
}
else if constexpr(TailNum == TailNumber::Two)
@@ -823,61 +806,52 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
__builtin_amdgcn_sched_barrier(0);
}
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_ford<Sequence<MRepeat, NRepeat>>{}([&](auto mn) {
constexpr auto m0 = Number<mn[Number<0>{}]>{};
constexpr auto n0 = Number<mn[Number<1>{}]>{};
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
// The block_sync_lds() here performs double duty:
// A) safeguard against data hazard because barrier from
// blockwise_gemm is moved here B) reduce VMEM FIFO congestion
// by applying small delays to different wavefronts It is
// performed near the end of MAC cluster to minimize lgkmcnt
// penalty
if constexpr(k0.value == KRepeat - 1 &&
k_.value == KPerInnerLoop - KPack &&
m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
{
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, k_ + ik))>{}];
});
// static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t)
// {
// constexpr index_t c_offset =
// c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
// c_thread_buf(Number<c_offset>{}) +=
// c_thread_buf_per_scale[Number<t>{}] *
// type_convert<AccDataType>(b_scale_thread_buf[n0]);
// });
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
// The block_sync_lds() here performs double duty:
// A) safeguard against data hazard because barrier from
// blockwise_gemm is moved here B) reduce VMEM FIFO congestion
// by applying small delays to different wavefronts It is
// performed near the end of MAC cluster to minimize lgkmcnt
// penalty
if constexpr(k0.value == KRepeat - 1 &&
k_.value == KPerInnerLoop - KPack &&
m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
{
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
});
});
__builtin_amdgcn_sched_barrier(0);
@@ -944,54 +918,46 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
__builtin_amdgcn_sched_barrier(0);
}
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_ford<Sequence<MRepeat, NRepeat>>{}([&](auto mn) {
constexpr auto m0 = Number<mn[Number<0>{}]>{};
constexpr auto n0 = Number<mn[Number<1>{}]>{};
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
if constexpr(k0.value == KRepeat - 1 &&
k_.value == KPerInnerLoop - KPack &&
m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
{
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, k_ + ik))>{}];
});
// static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
// constexpr index_t c_offset =
// c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
// c_thread_buf(Number<c_offset>{}) +=
// c_thread_buf_per_scale[Number<t>{}] *
// type_convert<AccDataType>(b_scale_thread_buf[n0]);
// });
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
if constexpr(k0.value == KRepeat - 1 &&
k_.value == KPerInnerLoop - KPack &&
m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
{
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
});
});
__builtin_amdgcn_sched_barrier(0);
@@ -1041,54 +1007,43 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
__builtin_amdgcn_sched_barrier(0);
}
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_ford<Sequence<MRepeat, NRepeat>>{}([&](auto mn) {
constexpr auto m0 = Number<mn[Number<0>{}]>{};
constexpr auto n0 = Number<mn[Number<1>{}]>{};
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
if constexpr(k0.value == KRepeat - 1 &&
k_.value == KPerInnerLoop - KPack &&
m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
{
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, k_ + ik))>{}];
});
// static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
// constexpr index_t c_offset =
// c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
// c_thread_buf(Number<c_offset>{}) +=
// c_thread_buf_per_scale[Number<t>{}] *
// type_convert<AccDataType>(b_scale_thread_buf[n0]);
// });
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
if constexpr(k0.value == KRepeat - 1 && k_.value == KPerInnerLoop - KPack &&
m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
{
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
});
});
__builtin_amdgcn_sched_barrier(0);
@@ -1125,54 +1080,43 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
__builtin_amdgcn_sched_barrier(0);
}
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_ford<Sequence<MRepeat, NRepeat>>{}([&](auto mn) {
constexpr auto m0 = Number<mn[Number<0>{}]>{};
constexpr auto n0 = Number<mn[Number<1>{}]>{};
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
if constexpr(k0.value == KRepeat - 1 &&
k_.value == KPerInnerLoop - KPack &&
m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
{
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, k_ + ik))>{}];
});
// static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
// constexpr index_t c_offset =
// c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
// c_thread_buf(Number<c_offset>{}) +=
// c_thread_buf_per_scale[Number<t>{}] *
// type_convert<AccDataType>(b_scale_thread_buf[n0]);
// });
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
if constexpr(k0.value == KRepeat - 1 && k_.value == KPerInnerLoop - KPack &&
m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
{
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
});
});
__builtin_amdgcn_sched_barrier(0);

View File

@@ -363,34 +363,29 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_ford<Sequence<KRepeat, MRepeat, NRepeat>>{}([&](auto kmn) {
constexpr auto k0 = Number<kmn[Number<0>{}]>{};
constexpr auto m0 = Number<kmn[Number<1>{}]>{};
constexpr auto n0 = Number<kmn[Number<2>{}]>{};
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
block_sync_lds();
@@ -423,32 +418,28 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
// tail
if constexpr(TailNum == TailNumber::Full)
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_ford<Sequence<KRepeat, MRepeat, NRepeat>>{}([&](auto kmn) {
constexpr auto k0 = Number<kmn[Number<0>{}]>{};
constexpr auto m0 = Number<kmn[Number<1>{}]>{};
constexpr auto n0 = Number<kmn[Number<2>{}]>{};
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
// Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle
// latency

View File

@@ -471,42 +471,41 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
static_ford<Sequence<MRepeat, NRepeat>>{}([&](auto mn) {
constexpr auto m0 = Number<mn[Number<0>{}]>{};
constexpr auto n0 = Number<mn[Number<1>{}]>{};
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
});
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
});
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()[Number<t>{}] *
type_convert<AccDataType>(c_scale_thread_buf[m0]);
});
.template AsType<AccDataType>()[Number<t>{}] *
type_convert<AccDataType>(c_scale_thread_buf[m0]);
});
});
@@ -573,41 +572,39 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
// tail
if constexpr(TailNum == TailNumber::Full)
{
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
static_ford<Sequence<MRepeat, NRepeat>>{}([&](auto mn) {
constexpr auto m0 = Number<mn[Number<0>{}]>{};
constexpr auto n0 = Number<mn[Number<1>{}]>{};
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
});
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
});
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()[Number<t>{}] *
type_convert<AccDataType>(c_scale_thread_buf[m0]);
});
.template AsType<AccDataType>()[Number<t>{}] *
type_convert<AccDataType>(c_scale_thread_buf[m0]);
});
});
__builtin_amdgcn_sched_barrier(0);

View File

@@ -428,34 +428,29 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
b_scale_grid_desc, b_scale_thread_copy_step.At(Number<1>{}));
}
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_ford<Sequence<KRepeat, MRepeat, NRepeat>>{}([&](auto kmn) {
constexpr auto k0 = Number<kmn[Number<0>{}]>{};
constexpr auto m0 = Number<kmn[Number<1>{}]>{};
constexpr auto n0 = Number<kmn[Number<2>{}]>{};
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
block_sync_lds();
@@ -490,32 +485,28 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
// tail
if constexpr(TailNum == TailNumber::Full)
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_ford<Sequence<KRepeat, MRepeat, NRepeat>>{}([&](auto kmn) {
constexpr auto k0 = Number<kmn[Number<0>{}]>{};
constexpr auto m0 = Number<kmn[Number<1>{}]>{};
constexpr auto n0 = Number<kmn[Number<2>{}]>{};
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
__builtin_amdgcn_sched_barrier(0);
}

View File

@@ -459,49 +459,51 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto k_step = k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops;
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
[&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
I0,
Number<a_k_step_chunk>{}),
a_block_bufs(I0),
a_thread_desc_,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read block data in chunks to assemble correct thread vectors
static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}(
[&](auto chunk) {
constexpr auto b_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
I0,
Number<b_k_step_chunk>{}),
b_block_bufs(I0),
b_thread_desc_,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
b_thread_buf);
});
});
static_ford<
Sequence<MRepeat, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk)>>{}(
[&](auto mc) {
constexpr auto m0 = Number<mc[Number<0>{}]>{};
constexpr auto chunk = Number<mc[Number<1>{}]>{};
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
I0,
Number<a_k_step_chunk>{}),
a_block_bufs(I0),
a_thread_desc_,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
// read block data in chunks to assemble correct thread vectors
static_ford<
Sequence<NRepeat, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk)>>{}(
[&](auto nc) {
constexpr auto n0 = Number<nc[Number<0>{}]>{};
constexpr auto chunk = Number<nc[Number<1>{}]>{};
constexpr auto b_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
I0,
Number<b_k_step_chunk>{}),
b_block_bufs(I0),
b_thread_desc_,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
b_thread_buf);
});
});
// Global prefetch 2
@@ -577,91 +579,85 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
static_ford<
Sequence<MRepeat / MXdlPack, NRepeat / NXdlPack, KRepeat / KXdlPack>>{}(
[&](auto mnk) {
constexpr auto m0 = mnk[Number<0>{}];
constexpr auto n0 = mnk[Number<1>{}];
constexpr auto k0 = mnk[Number<2>{}];
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
vector_type<AScaleDataType, a_scale_thread_vec_size>
a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size>
b_scale_thread_vec;
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(
scale_comp_buf)[Number<a_scale_offset + s>{}];
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(
scale_comp_buf)[Number<a_scale_offset + s>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(
scale_comp_buf)[Number<b_scale_offset + s>{}];
});
static_ford<Sequence<KXdlPack, MXdlPack, NXdlPack>>{}([&](auto
kmn_xdl) {
constexpr auto ikxdl = Number<kmn_xdl[Number<0>{}]>{};
constexpr auto imxdl = Number<kmn_xdl[Number<1>{}]>{};
constexpr auto inxdl = Number<kmn_xdl[Number<2>{}]>{};
constexpr auto kxdl = ikxdl + k0 * KXdlPack;
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(
scale_comp_buf)[Number<b_scale_offset + s>{}];
});
using mfma_input_type_a = typename vector_type< //
ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
constexpr auto kxdl = ikxdl + k0 * KXdlPack;
using mfma_input_type_b = typename vector_type< //
ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
using mfma_scale_input_type_a = typename vector_type< //
AScaleDataType,
a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b = typename vector_type< //
BScaleDataType,
b_scale_thread_vec_size>::type;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(
ik) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(
ik) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
});
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, imxdl, inxdl, 0));
using mfma_input_type_a = typename vector_type< //
ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b = typename vector_type< //
ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
using mfma_scale_input_type_a = typename vector_type< //
AScaleDataType,
a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b = typename vector_type< //
BScaleDataType,
b_scale_thread_vec_size>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, imxdl, inxdl, 0));
// MFMA accumulation
xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
ikxdl * NXdlPack + inxdl>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec
.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec
.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(
Number<c_offset>{}));
});
});
});
// MFMA accumulation
xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
ikxdl * NXdlPack + inxdl>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
// k indexes mapping to threads for 32x32x64:
// t0 : |0 --> 15 32 --> 47 | 64 --> 79 96 --> 111 | etc.
@@ -774,83 +770,81 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
});
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
static_ford<Sequence<MRepeat / MXdlPack, NRepeat / NXdlPack, KRepeat / KXdlPack>>{}(
[&](auto mnk) {
constexpr auto m0 = mnk[Number<0>{}];
constexpr auto n0 = mnk[Number<1>{}];
constexpr auto k0 = mnk[Number<2>{}];
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
});
static_ford<Sequence<KXdlPack, MXdlPack, NXdlPack>>{}([&](auto kmn_xdl) {
constexpr auto ikxdl = Number<kmn_xdl[Number<0>{}]>{};
constexpr auto imxdl = Number<kmn_xdl[Number<1>{}]>{};
constexpr auto inxdl = Number<kmn_xdl[Number<2>{}]>{};
constexpr auto kxdl = ikxdl + k0 * KXdlPack;
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
});
using mfma_input_type_a = typename vector_type< //
ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
constexpr auto kxdl = ikxdl + k0 * KXdlPack;
using mfma_input_type_b = typename vector_type< //
ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
using mfma_scale_input_type_a = typename vector_type< //
AScaleDataType,
a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b = typename vector_type< //
BScaleDataType,
b_scale_thread_vec_size>::type;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
});
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0));
using mfma_input_type_a = typename vector_type< //
ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b = typename vector_type< //
ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
using mfma_scale_input_type_a = typename vector_type< //
AScaleDataType,
a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b = typename vector_type< //
BScaleDataType,
b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, imxdl, inxdl, 0));
// MFMA accumulation
xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
ikxdl * NXdlPack + inxdl>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec
.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec
.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
// MFMA accumulation
xdlops_gemm
.template Run<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
__builtin_amdgcn_s_waitcnt(3952);
block_sync_lds();
@@ -858,210 +852,206 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto k_step =
k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops;
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
[&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step +
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
I0,
Number<a_k_step_chunk>{}),
a_block_bufs(I1),
a_thread_desc_,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read block data in chunks to assemble correct thread vectors
static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}(
[&](auto chunk) {
constexpr auto b_k_step_chunk =
k_step +
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
I0,
Number<b_k_step_chunk>{}),
b_block_bufs(I1),
b_thread_desc_,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
b_thread_buf);
});
});
static_ford<
Sequence<MRepeat, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk)>>{}(
[&](auto mc) {
constexpr auto m0 = Number<mc[Number<0>{}]>{};
constexpr auto chunk = Number<mc[Number<1>{}]>{};
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
I0,
Number<a_k_step_chunk>{}),
a_block_bufs(I1),
a_thread_desc_,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
// read block data in chunks to assemble correct thread vectors
static_ford<
Sequence<NRepeat, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk)>>{}(
[&](auto nc) {
constexpr auto n0 = Number<nc[Number<0>{}]>{};
constexpr auto chunk = Number<nc[Number<1>{}]>{};
constexpr auto b_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
I0,
Number<b_k_step_chunk>{}),
b_block_bufs(I1),
b_thread_desc_,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
b_thread_buf);
});
});
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
static_ford<Sequence<MRepeat / MXdlPack, NRepeat / NXdlPack, KRepeat / KXdlPack>>{}(
[&](auto mnk) {
constexpr auto m0 = mnk[Number<0>{}];
constexpr auto n0 = mnk[Number<1>{}];
constexpr auto k0 = mnk[Number<2>{}];
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(I1)[Number<a_scale_offset + s>{}];
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(I1)[Number<a_scale_offset + s>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(I1)[Number<b_scale_offset + s>{}];
});
static_ford<Sequence<KXdlPack, MXdlPack, NXdlPack>>{}([&](auto kmn_xdl) {
constexpr auto ikxdl = Number<kmn_xdl[Number<0>{}]>{};
constexpr auto imxdl = Number<kmn_xdl[Number<1>{}]>{};
constexpr auto inxdl = Number<kmn_xdl[Number<2>{}]>{};
constexpr auto kxdl = ikxdl + k0 * KXdlPack;
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(I1)[Number<b_scale_offset + s>{}];
});
using mfma_input_type_a = typename vector_type< //
ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
constexpr auto kxdl = ikxdl + k0 * KXdlPack;
using mfma_input_type_b = typename vector_type< //
ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
using mfma_scale_input_type_a = typename vector_type< //
AScaleDataType,
a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b = typename vector_type< //
BScaleDataType,
b_scale_thread_vec_size>::type;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
});
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0));
using mfma_input_type_a = typename vector_type< //
ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b = typename vector_type< //
ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
using mfma_scale_input_type_a = typename vector_type< //
AScaleDataType,
a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b = typename vector_type< //
BScaleDataType,
b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, imxdl, inxdl, 0));
// MFMA accumulation
xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
ikxdl * NXdlPack + inxdl>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec
.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec
.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
// MFMA accumulation
xdlops_gemm
.template Run<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
else if constexpr(TailNum == TailNumber::Odd)
{
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
static_ford<Sequence<MRepeat / MXdlPack, NRepeat / NXdlPack, KRepeat / KXdlPack>>{}(
[&](auto mnk) {
constexpr auto m0 = mnk[Number<0>{}];
constexpr auto n0 = mnk[Number<1>{}];
constexpr auto k0 = mnk[Number<2>{}];
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops "
"per Thread.");
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
});
static_ford<Sequence<KXdlPack, MXdlPack, NXdlPack>>{}([&](auto kmn_xdl) {
constexpr auto ikxdl = Number<kmn_xdl[Number<0>{}]>{};
constexpr auto imxdl = Number<kmn_xdl[Number<1>{}]>{};
constexpr auto inxdl = Number<kmn_xdl[Number<2>{}]>{};
constexpr auto kxdl = ikxdl + k0 * KXdlPack;
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
});
using mfma_input_type_a = typename vector_type< //
ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
constexpr auto kxdl = ikxdl + k0 * KXdlPack;
using mfma_input_type_b = typename vector_type< //
ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
using mfma_scale_input_type_a = typename vector_type< //
AScaleDataType,
a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b = typename vector_type< //
BScaleDataType,
b_scale_thread_vec_size>::type;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
});
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, imxdl, inxdl, 0));
using mfma_input_type_a = typename vector_type< //
ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b = typename vector_type< //
ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
using mfma_scale_input_type_a = typename vector_type< //
AScaleDataType,
a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b = typename vector_type< //
BScaleDataType,
b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, imxdl, inxdl, 0));
// MFMA accumulation
xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
ikxdl * NXdlPack + inxdl>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec
.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec
.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
// MFMA accumulation
xdlops_gemm
.template Run<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
}

View File

@@ -261,54 +261,49 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlockGemmPipelineScheduler:
// Stage 1
// global read more
static_for<0, buffer_load_stages_more, 1>{}([&](auto /*i*/) {
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
static_ford<Sequence<buffer_load_stages_more, num_mfma_perstage>>{}([&](auto ii) {
constexpr auto imfma = Number<ii[Number<1>{}]>{};
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(imfma % buffer_load_issue_point_interval_more == 0)
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if constexpr(imfma % buffer_load_issue_point_interval_more == 0)
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(
0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
// global read less
static_for<0, (num_total_stages - 2 - buffer_load_stages_more), 1>{}([&](auto /*i*/) {
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(imfma % buffer_load_issue_point_interval_less == 0)
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(
0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
static_ford<Sequence<(num_total_stages - 2 - buffer_load_stages_more),
num_mfma_perstage>>{}([&](auto ii) {
constexpr auto imfma = Number<ii[Number<1>{}]>{};
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(imfma % buffer_load_issue_point_interval_less == 0)
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
// Stage 2, Sync
// lds synchronization, prefetch next loop local A
static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto /*i*/) {
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0)
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(
0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
static_ford<Sequence<num_ds_read_a_prefetch_stages, num_mfma_perstage>>{}([&](auto ii) {
constexpr auto imfma = Number<ii[Number<1>{}]>{};
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0)
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
}
else
@@ -537,25 +532,24 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlockGemmPipelineScheduler:
// Local prefetch 1, sync the async load
__builtin_amdgcn_s_waitcnt(async_vmcnt_encoding);
block_sync_lds();
static_for<0, LocalPrefetchStages, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
(APackedSize * KPack / xdlops_gemm.K1PerXdlops);
static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
[&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_m3_k,
make_tuple(
I0, I0, Number<m0 % MXdlPack>{}, I0, Number<a_k_step_chunk>{}),
a_block_bufs(I0),
a_thread_desc_,
make_tuple(
I0, I0, Number<m0 % MXdlPack>{}, k, Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
});
static_ford<Sequence<LocalPrefetchStages, KRepeat>>{}([&](auto mk) {
constexpr auto m0 = Number<mk[Number<0>{}]>{};
constexpr auto k = Number<mk[Number<1>{}]>{};
constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
(APackedSize * KPack / xdlops_gemm.K1PerXdlops);
static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
[&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_m3_k,
make_tuple(I0, I0, Number<m0 % MXdlPack>{}, I0, Number<a_k_step_chunk>{}),
a_block_bufs(I0),
a_thread_desc_,
make_tuple(
I0, I0, Number<m0 % MXdlPack>{}, k, Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
});
// Global prefetch 2

View File

@@ -368,79 +368,10 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_bufs[mfma_reg_buf]
[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_bufs[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
HotLoopScheduler();
};
LoopFunc(I1, I1, I0, I0);
LoopFunc(I0, I0, I1, I1);
i += HotloopUnroll;
} while(i < (num_loop - PrefetchStages));
}
auto ReadWriteCompFunc = [&](auto lds_read_buf,
auto lds_read_reg_buf,
auto lds_write_buf,
auto mfma_reg_buf) {
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(lds_read_buf),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_bufs(lds_read_reg_buf));
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(lds_read_buf),
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_bufs(lds_read_reg_buf));
});
});
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf));
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf));
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_ford<Sequence<KRepeat, MRepeat, NRepeat>>{}([&](auto kmn) {
constexpr auto k0 = Number<kmn[Number<0>{}]>{};
constexpr auto m0 = Number<kmn[Number<1>{}]>{};
constexpr auto n0 = Number<kmn[Number<2>{}]>{};
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
@@ -463,11 +394,72 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
HotLoopScheduler();
};
HotLoopScheduler();
};
LoopFunc(I1, I1, I0, I0);
LoopFunc(I0, I0, I1, I1);
i += HotloopUnroll;
} while(i < (num_loop - PrefetchStages));
}
auto ReadWriteCompFunc =
[&](auto lds_read_buf, auto lds_read_reg_buf, auto lds_write_buf, auto mfma_reg_buf) {
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(lds_read_buf),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_bufs(lds_read_reg_buf));
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(lds_read_buf),
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_bufs(lds_read_reg_buf));
});
});
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf));
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf));
static_ford<Sequence<KRepeat, MRepeat, NRepeat>>{}([&](auto kmn) {
constexpr auto k0 = Number<kmn[Number<0>{}]>{};
constexpr auto m0 = Number<kmn[Number<1>{}]>{};
constexpr auto n0 = Number<kmn[Number<2>{}]>{};
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
HotLoopScheduler();
};
auto ReadCompFunc = [&](auto lds_read_buf, auto lds_read_reg_buf, auto mfma_reg_buf) {
block_sync_lds();
@@ -491,64 +483,60 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
});
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_ford<Sequence<KRepeat, MRepeat, NRepeat>>{}([&](auto kmn) {
constexpr auto k0 = Number<kmn[Number<0>{}]>{};
constexpr auto m0 = Number<kmn[Number<1>{}]>{};
constexpr auto n0 = Number<kmn[Number<2>{}]>{};
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
HotLoopScheduler();
};
auto CompFunc = [&](auto mfma_reg_buf) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_ford<Sequence<KRepeat, MRepeat, NRepeat>>{}([&](auto kmn) {
constexpr auto k0 = Number<kmn[Number<0>{}]>{};
constexpr auto m0 = Number<kmn[Number<1>{}]>{};
constexpr auto n0 = Number<kmn[Number<2>{}]>{};
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
};
// tail
@@ -918,81 +906,10 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v4<BlockGemmPipelineScheduler::Int
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_bufs[mfma_reg_buf]
[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_bufs[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
HotLoopScheduler();
};
LoopFunc(I1, I1, I0, I0);
LoopFunc(I0, I0, I1, I1);
i += HotloopUnroll;
} while(i < (num_loop - PrefetchStages));
}
auto ReadWriteCompFunc = [&](auto lds_read_buf,
auto lds_read_reg_buf,
auto lds_write_buf,
auto mfma_reg_buf) {
block_sync_lds_direct_load();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(lds_read_buf),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_bufs(lds_read_reg_buf));
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(lds_read_buf),
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_bufs(lds_read_reg_buf));
});
});
a_blockwise_copy.Run(
a_grid_desc, a_grid_buf, a_block_desc, a_block_buf.At(lds_write_buf));
b_blockwise_copy.Run(
b_grid_desc, b_grid_buf, b_block_desc, b_block_buf.At(lds_write_buf));
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_ford<Sequence<KRepeat, MRepeat, NRepeat>>{}([&](auto kmn) {
constexpr auto k0 = Number<kmn[Number<0>{}]>{};
constexpr auto m0 = Number<kmn[Number<1>{}]>{};
constexpr auto n0 = Number<kmn[Number<2>{}]>{};
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
@@ -1015,11 +932,74 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v4<BlockGemmPipelineScheduler::Int
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
HotLoopScheduler();
};
HotLoopScheduler();
};
LoopFunc(I1, I1, I0, I0);
LoopFunc(I0, I0, I1, I1);
i += HotloopUnroll;
} while(i < (num_loop - PrefetchStages));
}
auto ReadWriteCompFunc =
[&](auto lds_read_buf, auto lds_read_reg_buf, auto lds_write_buf, auto mfma_reg_buf) {
block_sync_lds_direct_load();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(lds_read_buf),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_bufs(lds_read_reg_buf));
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(lds_read_buf),
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_bufs(lds_read_reg_buf));
});
});
a_blockwise_copy.Run(
a_grid_desc, a_grid_buf, a_block_desc, a_block_buf.At(lds_write_buf));
b_blockwise_copy.Run(
b_grid_desc, b_grid_buf, b_block_desc, b_block_buf.At(lds_write_buf));
static_ford<Sequence<KRepeat, MRepeat, NRepeat>>{}([&](auto kmn) {
constexpr auto k0 = Number<kmn[Number<0>{}]>{};
constexpr auto m0 = Number<kmn[Number<1>{}]>{};
constexpr auto n0 = Number<kmn[Number<2>{}]>{};
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
HotLoopScheduler();
};
auto ReadCompFunc = [&](auto lds_read_buf, auto lds_read_reg_buf, auto mfma_reg_buf) {
block_sync_lds_direct_load();
@@ -1043,64 +1023,60 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v4<BlockGemmPipelineScheduler::Int
});
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_ford<Sequence<KRepeat, MRepeat, NRepeat>>{}([&](auto kmn) {
constexpr auto k0 = Number<kmn[Number<0>{}]>{};
constexpr auto m0 = Number<kmn[Number<1>{}]>{};
constexpr auto n0 = Number<kmn[Number<2>{}]>{};
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
HotLoopScheduler();
};
auto CompFunc = [&](auto mfma_reg_buf) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_ford<Sequence<KRepeat, MRepeat, NRepeat>>{}([&](auto kmn) {
constexpr auto k0 = Number<kmn[Number<0>{}]>{};
constexpr auto m0 = Number<kmn[Number<1>{}]>{};
constexpr auto n0 = Number<kmn[Number<2>{}]>{};
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
};
// tail

View File

@@ -356,23 +356,23 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
// Local prefetch 1
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_bufs(I0));
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(I0),
b_scale_thread_bufs(I0)[n0],
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_bufs(I0));
});
static_ford<Sequence<KRepeat, MRepeat>>{}([&](auto km) {
constexpr auto k = Number<km[Number<0>{}]>{};
constexpr auto m0 = Number<km[Number<1>{}]>{};
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_bufs(I0));
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(I0),
b_scale_thread_bufs(I0)[n0],
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_bufs(I0));
});
});
@@ -477,80 +477,10 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_bufs[mfma_reg_buf]
[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_bufs[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
HotLoopScheduler();
};
LoopFunc(I1, I1, I0, I0);
LoopFunc(I0, I0, I1, I1);
i += HotloopUnroll;
} while(i < (num_loop - PrefetchStages));
}
auto ReadWriteCompFunc = [&](auto lds_read_buf,
auto lds_read_reg_buf,
auto lds_write_buf,
auto mfma_reg_buf) {
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(lds_read_buf),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_bufs(lds_read_reg_buf));
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(lds_read_buf),
b_scale_thread_bufs(lds_read_buf)[n0],
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_bufs(lds_read_reg_buf));
});
});
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf));
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf));
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_ford<Sequence<KRepeat, MRepeat, NRepeat>>{}([&](auto kmn) {
constexpr auto k0 = Number<kmn[Number<0>{}]>{};
constexpr auto m0 = Number<kmn[Number<1>{}]>{};
constexpr auto n0 = Number<kmn[Number<2>{}]>{};
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
@@ -573,11 +503,73 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
HotLoopScheduler();
};
HotLoopScheduler();
};
LoopFunc(I1, I1, I0, I0);
LoopFunc(I0, I0, I1, I1);
i += HotloopUnroll;
} while(i < (num_loop - PrefetchStages));
}
auto ReadWriteCompFunc =
[&](auto lds_read_buf, auto lds_read_reg_buf, auto lds_write_buf, auto mfma_reg_buf) {
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(lds_read_buf),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_bufs(lds_read_reg_buf));
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(lds_read_buf),
b_scale_thread_bufs(lds_read_buf)[n0],
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_bufs(lds_read_reg_buf));
});
});
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf));
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf));
static_ford<Sequence<KRepeat, MRepeat, NRepeat>>{}([&](auto kmn) {
constexpr auto k0 = Number<kmn[Number<0>{}]>{};
constexpr auto m0 = Number<kmn[Number<1>{}]>{};
constexpr auto n0 = Number<kmn[Number<2>{}]>{};
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
HotLoopScheduler();
};
auto ReadCompFunc = [&](auto lds_read_buf, auto lds_read_reg_buf, auto mfma_reg_buf) {
block_sync_lds();
@@ -602,64 +594,60 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
});
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_ford<Sequence<KRepeat, MRepeat, NRepeat>>{}([&](auto kmn) {
constexpr auto k0 = Number<kmn[Number<0>{}]>{};
constexpr auto m0 = Number<kmn[Number<1>{}]>{};
constexpr auto n0 = Number<kmn[Number<2>{}]>{};
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
HotLoopScheduler();
};
auto CompFunc = [&](auto mfma_reg_buf) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_ford<Sequence<KRepeat, MRepeat, NRepeat>>{}([&](auto kmn) {
constexpr auto k0 = Number<kmn[Number<0>{}]>{};
constexpr auto m0 = Number<kmn[Number<1>{}]>{};
constexpr auto n0 = Number<kmn[Number<2>{}]>{};
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
};

View File

@@ -590,27 +590,26 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
});
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(m0, I0, I0, ik))>{}];
});
static_for<0, KPack, 1>{}([&](auto ik) {
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, I0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
static_ford<Sequence<MRepeat, NRepeat>>{}([&](auto mn) {
constexpr auto m0 = Number<mn[Number<0>{}]>{};
constexpr auto n0 = Number<mn[Number<1>{}]>{};
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(m0, I0, I0, ik))>{}];
});
static_for<0, KPack, 1>{}([&](auto ik) {
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, I0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
HotLoopScheduler();

View File

@@ -365,56 +365,50 @@ struct BlockwiseGemmWMMA
}
else
{
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KPerBlock / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of
// k=0,kpack*1, ..
// read B
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
b_thread_buf);
// read A
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, I0, I0, I0, I0),
a_thread_buf);
static_ford<Sequence<NRepeat, MRepeat, KPerBlock / KPack>>{}([&](auto nmk) {
constexpr auto n0 = Number<nmk[Number<0>{}]>{};
constexpr auto m0 = Number<nmk[Number<1>{}]>{};
constexpr auto k = Number<nmk[Number<2>{}]>{}; // k=0,1,2 instead of k=0,kpack*1, ..
// read B
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
b_thread_buf);
// read A
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, I0, I0, I0, I0),
a_thread_buf);
vector_type<FloatA, KPack / A_KRow> a_thread_vec;
vector_type<FloatB, KPack / B_KRow> b_thread_vec;
vector_type<FloatA, KPack / A_KRow> a_thread_vec;
vector_type<FloatB, KPack / B_KRow> b_thread_vec;
static_for<0, KPack / A_KRow, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatA>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(i / A_K1, m0, 0, 0, 0, i % A_K1))>{}];
});
static_for<0, KPack / B_KRow, 1>{}([&](auto i) {
b_thread_vec.template AsType<FloatB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(i / B_K1, n0, 0, 0, 0, i % B_K1))>{}];
});
using wmma_input_type_a =
typename vector_type<FloatA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<FloatB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
wmma_gemm.template Run<>(
a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack / A_KRow, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatA>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(i / A_K1, m0, 0, 0, 0, i % A_K1))>{}];
});
static_for<0, KPack / B_KRow, 1>{}([&](auto i) {
b_thread_vec.template AsType<FloatB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(i / B_K1, n0, 0, 0, 0, i % B_K1))>{}];
});
using wmma_input_type_a = typename vector_type<FloatA, WmmaK / A_KRow>::type;
using wmma_input_type_b = typename vector_type<FloatB, WmmaK / B_KRow>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
wmma_gemm.template Run<>(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
}
}
@@ -862,60 +856,47 @@ struct BlockwiseGemmWMMA
}
else
{
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KPerBlock / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of
// k=0,kpack*1, ..
// read B
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
b_thread_buf);
// read A
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, I0, I0, I0, I0),
a_thread_buf);
static_ford<Sequence<NRepeat, MRepeat, KPerBlock / KPack>>{}([&](auto nmk) {
constexpr auto n0 = Number<nmk[Number<0>{}]>{};
constexpr auto m0 = Number<nmk[Number<1>{}]>{};
constexpr auto k = Number<nmk[Number<2>{}]>{}; // k=0,1,2 instead of k=0,kpack*1, ..
// read B
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
b_thread_buf);
// read A
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, I0, I0, I0, I0),
a_thread_buf);
vector_type<FloatA, KPack> a_thread_vec;
vector_type<FloatB, KPack> b_thread_vec;
vector_type<FloatA, KPack> a_thread_vec;
vector_type<FloatB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto i) {
b_thread_vec.template AsType<FloatB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(i / B_K1 / B_KRow,
n0,
0,
(i / B_K1) % B_KRow,
0,
i % B_K1))>{}];
a_thread_vec.template AsType<FloatA>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(i / A_K1 / A_KRow,
m0,
0,
(i / A_K1) % A_KRow,
0,
i % A_K1))>{}];
});
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack, 1>{}([&](auto i) {
b_thread_vec.template AsType<FloatB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
i / B_K1 / B_KRow, n0, 0, (i / B_K1) % B_KRow, 0, i % B_K1))>{}];
a_thread_vec.template AsType<FloatA>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
i / A_K1 / A_KRow, m0, 0, (i / A_K1) % A_KRow, 0, i % A_K1))>{}];
});
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
}
}

View File

@@ -514,54 +514,54 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
__builtin_amdgcn_sched_barrier(0);
}
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ElementDataTypeA, KPack> a_thread_vec;
vector_type<ElementDataTypeB, KPack> b_thread_vec;
static_ford<Sequence<MRepeat, NRepeat>>{}([&](auto mn) {
constexpr auto m0 = Number<mn[Number<0>{}]>{};
constexpr auto n0 = Number<mn[Number<1>{}]>{};
vector_type<ElementDataTypeA, KPack> a_thread_vec;
vector_type<ElementDataTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto i) {
a_thread_vec.template AsType<ElementDataTypeA>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, 0, 0, k_ + i))>{}];
b_thread_vec.template AsType<ElementDataTypeB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, 0, 0, k_ + i))>{}];
});
using mfma_input_type_a =
typename vector_type<ElementDataTypeA, xdlops_gemm.K1PerXdlops>::type;
using mfma_input_type_b =
typename vector_type<ElementDataTypeB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
// The block_sync_lds() here performs double duty:
// A) safeguard against data hazard because barrier from blockwise_gemm is
// moved here B) reduce VMEM FIFO congestion by applying small delays to
// different wavefronts It is performed near the end of MAC cluster to
// minimize lgkmcnt penalty
if constexpr(k.value == KPerThread - KPerInnerLoop &&
k_.value == KPerInnerLoop - KPack && m0.value == MRepeat - 1 &&
n0.value == NRepeat - 1)
{
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
// TODO: insert setprio in more precise manner since we
// could have more than >1 MFMA instructions in single call
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
static_for<0, KPack, 1>{}([&](auto i) {
a_thread_vec.template AsType<ElementDataTypeA>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, 0, 0, k_ + i))>{}];
b_thread_vec.template AsType<ElementDataTypeB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, 0, 0, k_ + i))>{}];
});
using mfma_input_type_a =
typename vector_type<ElementDataTypeA, xdlops_gemm.K1PerXdlops>::type;
using mfma_input_type_b =
typename vector_type<ElementDataTypeB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
// The block_sync_lds() here performs double duty:
// A) safeguard against data hazard because barrier from blockwise_gemm is
// moved here B) reduce VMEM FIFO congestion by applying small delays to
// different wavefronts It is performed near the end of MAC cluster to
// minimize lgkmcnt penalty
if constexpr(k.value == KPerThread - KPerInnerLoop &&
k_.value == KPerInnerLoop - KPack && m0.value == MRepeat - 1 &&
n0.value == NRepeat - 1)
{
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
// TODO: insert setprio in more precise manner since we
// could have more than >1 MFMA instructions in single call
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
});
});
__builtin_amdgcn_sched_barrier(0);
@@ -953,44 +953,43 @@ struct BlockwiseGemmXdlops_v2
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
b_thread_desc_.GetElementSpaceSize());
static_for<0, KPerThread / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
static_ford<Sequence<KPerThread / KPack, MRepeat>>{}([&](auto km) {
constexpr auto k = Number<km[Number<0>{}]>{};
constexpr auto m0 = Number<km[Number<1>{}]>{};
// read A
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
b_thread_buf);
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, i))>{}];
b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, i))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
static_for<0, KPack, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, i))>{}];
b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, i))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
}

View File

@@ -98,13 +98,13 @@ struct BlockwiseSoftmax
});
// calculate exp for elements, P=exp(s-max)
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) = IgnoreNaN && ck::math::isnan(in_thread_buf[offset])
? 0
: math::exp(in_thread_buf[offset] - max_value_buf(iM));
});
static_ford<Sequence<MRepeat, KRepeat>>{}([&](auto ii) {
constexpr auto iM = Number<ii[Number<0>{}]>{};
constexpr auto iK = Number<ii[Number<1>{}]>{};
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) = IgnoreNaN && ck::math::isnan(in_thread_buf[offset])
? 0
: math::exp(in_thread_buf[offset] - max_value_buf(iM));
});
// sum data