[Navi3x] Add fp16/int8 wmma conv forward instances (#746)

* fix wmma gemm int8; add grouped conv int8 example

* Add int8 gemm-bilinear instances

* compile sanity check unknown

* Sanity pass + clang-format

* add int8 conv profiler instances

* solve merge conflict

---------

Co-authored-by: zjing14 <zhangjing14@gmail.com>
Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
Haocong WANG
2023-09-08 10:59:26 +08:00
committed by GitHub
parent 37a8c1f756
commit 562b4cec48
19 changed files with 1192 additions and 42 deletions

View File

@@ -599,7 +599,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
// check if it's 1x1, stride=1 conv
for(index_t i = 0; i < NDimSpatial; ++i)
{
const index_t X = arg.b_g_k_c_xs_lengths_[i + 2];
const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
const index_t ConvStride = arg.conv_filter_strides_[i];
const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i];
@@ -616,7 +616,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
// check if it's 1x1 conv
for(index_t i = 0; i < NDimSpatial; ++i)
{
const index_t X = arg.b_g_k_c_xs_lengths_[i + 2];
const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i];