mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 20:51:23 +00:00
[Navi3x] Add fp16/int8 wmma conv forward instances (#746)
* fix wmma gemm int8; add grouped conv int8 example * Add int8 gemm-bilinear instances * compile sanity check unknown * Sanity pass + clang-format * add int8 conv profiler instances * solve merge conflict --------- Co-authored-by: zjing14 <zhangjing14@gmail.com> Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
@@ -221,49 +221,102 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
// read A
|
||||
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
|
||||
make_tuple(Number<k * WmmaK / A_K1>{}, m0, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, m0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
// basic intrinsic to determine loopover direction
|
||||
if constexpr(MRepeat < NRepeat)
|
||||
{
|
||||
static_for<0, KPerBlock / WmmaK, 1>{}(
|
||||
[&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
// read A
|
||||
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
|
||||
make_tuple(Number<k * WmmaK / A_K1>{}, m0, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, m0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
// read B
|
||||
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
|
||||
make_tuple(Number<k * WmmaK / B_K1>{}, n0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, n0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
vector_type<FloatA, WmmaK> a_thread_vec;
|
||||
vector_type<FloatB, WmmaK> b_thread_vec;
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
// read B
|
||||
b_thread_copy_.Run(
|
||||
b_block_desc_k0_n0_n1_n2_k1,
|
||||
make_tuple(Number<k * WmmaK / B_K1>{}, n0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, n0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
vector_type<FloatA, WmmaK> a_thread_vec;
|
||||
vector_type<FloatB, WmmaK> b_thread_vec;
|
||||
|
||||
static_for<0, WmmaK, 1>{}([&](auto i) {
|
||||
a_thread_vec.template AsType<FloatA>()(i) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(i / A_K1, m0, 0, 0, i % A_K1))>{}];
|
||||
b_thread_vec.template AsType<FloatB>()(i) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(i / B_K1, n0, 0, 0, i % B_K1))>{}];
|
||||
static_for<0, WmmaK, 1>{}([&](auto i) {
|
||||
a_thread_vec.template AsType<FloatA>()(i) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(i / A_K1, m0, 0, 0, i % A_K1))>{}];
|
||||
b_thread_vec.template AsType<FloatB>()(i) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(i / B_K1, n0, 0, 0, i % B_K1))>{}];
|
||||
});
|
||||
|
||||
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
|
||||
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
wmma_gemm.template Run(
|
||||
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
|
||||
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
|
||||
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
wmma_gemm.template Run(
|
||||
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, KPerBlock / WmmaK, 1>{}(
|
||||
[&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
// read B
|
||||
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
|
||||
make_tuple(Number<k * WmmaK / B_K1>{}, n0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, n0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
// read A
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_k0_m0_m1_m2_k1,
|
||||
make_tuple(Number<k * WmmaK / A_K1>{}, m0, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, m0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
vector_type<FloatA, WmmaK> a_thread_vec;
|
||||
vector_type<FloatB, WmmaK> b_thread_vec;
|
||||
|
||||
static_for<0, WmmaK, 1>{}([&](auto i) {
|
||||
a_thread_vec.template AsType<FloatA>()(i) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(i / A_K1, m0, 0, 0, i % A_K1))>{}];
|
||||
b_thread_vec.template AsType<FloatB>()(i) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(i / B_K1, n0, 0, 0, i % B_K1))>{}];
|
||||
});
|
||||
|
||||
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
|
||||
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
wmma_gemm.template Run(
|
||||
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
|
||||
Reference in New Issue
Block a user