[Navi3x] Add fp16/int8 wmma conv forward instances (#746)

* fix wmma gemm int8; add grouped conv int8 example

* Add int8 gemm-bilinear instances

* compile sanity check unknown

* Sanity pass + clang-format

* add int8 conv profiler instances

* solve merge conflict

---------

Co-authored-by: zjing14 <zhangjing14@gmail.com>
Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
Haocong WANG
2023-09-08 10:59:26 +08:00
committed by GitHub
parent 37a8c1f756
commit 562b4cec48
19 changed files with 1192 additions and 42 deletions

View File

@@ -221,49 +221,102 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
b_thread_desc_.GetElementSpaceSize());
static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k * WmmaK / A_K1>{}, m0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, I0, I0, I0),
a_thread_buf);
// basic intrinsic to determine loopover direction
if constexpr(MRepeat < NRepeat)
{
static_for<0, KPerBlock / WmmaK, 1>{}(
[&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k * WmmaK / A_K1>{}, m0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, I0, I0, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k * WmmaK / B_K1>{}, n0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0),
b_thread_buf);
vector_type<FloatA, WmmaK> a_thread_vec;
vector_type<FloatB, WmmaK> b_thread_vec;
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k * WmmaK / B_K1>{}, n0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0),
b_thread_buf);
vector_type<FloatA, WmmaK> a_thread_vec;
vector_type<FloatB, WmmaK> b_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatA>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(i / A_K1, m0, 0, 0, i % A_K1))>{}];
b_thread_vec.template AsType<FloatB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(i / B_K1, n0, 0, 0, i % B_K1))>{}];
static_for<0, WmmaK, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatA>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(i / A_K1, m0, 0, 0, i % A_K1))>{}];
b_thread_vec.template AsType<FloatB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(i / B_K1, n0, 0, 0, i % B_K1))>{}];
});
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
else
{
static_for<0, KPerBlock / WmmaK, 1>{}(
[&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k * WmmaK / B_K1>{}, n0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0),
b_thread_buf);
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k * WmmaK / A_K1>{}, m0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, I0, I0, I0),
a_thread_buf);
vector_type<FloatA, WmmaK> a_thread_vec;
vector_type<FloatB, WmmaK> b_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatA>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(i / A_K1, m0, 0, 0, i % A_K1))>{}];
b_thread_vec.template AsType<FloatB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(i / B_K1, n0, 0, 0, i % B_K1))>{}];
});
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
}
protected: