mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
fix wmma gemm int8; add grouped conv int8 example (#716)
[ROCm/composable_kernel commit: 6eef0755c9]
This commit is contained in:
@@ -20,4 +20,5 @@ if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS M
|
||||
endif()
|
||||
if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
|
||||
add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp)
|
||||
add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp)
|
||||
endif()
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common_wmma.hpp"
|
||||
|
||||
// kernel data types
|
||||
using InKernelDataType = I8;
|
||||
using WeiKernelDataType = I8;
|
||||
using AccDataType = I32;
|
||||
using CShuffleDataType = I8;
|
||||
using BiasKernelDataType = I8;
|
||||
using ResidualKernelDataType = I8;
|
||||
using OutKernelDataType = I8;
|
||||
|
||||
// tensor data types
|
||||
using InUserDataType = InKernelDataType;
|
||||
using WeiUserDataType = WeiKernelDataType;
|
||||
using OutUserDataType = OutKernelDataType;
|
||||
|
||||
using InElementOp = PassThrough;
|
||||
using WeiElementOp = PassThrough;
|
||||
using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd;
|
||||
|
||||
#include "run_grouped_conv_fwd_bias_relu_add_wmma_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); }
|
||||
@@ -262,12 +262,12 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
|
||||
|
||||
template <index_t MPerWmma,
|
||||
index_t NPerWmma,
|
||||
bool neg_a,
|
||||
bool neg_b,
|
||||
bool clamp,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
class FloatC,
|
||||
bool neg_a = false,
|
||||
bool neg_b = false,
|
||||
bool clamp = false>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
if constexpr(wave_size == 32)
|
||||
|
||||
Reference in New Issue
Block a user