Gemm+Reduce Fusion (#128)

* add gridwise gemm v4r1

* rename

* adding gemm+reduce

* adding gemm+reduce

* adding gemm+reduce

* adding gemm+reduce

* use sfc in shuffling

* remove hard-coded values

* remove hard-coded values

* refactor

* fix build

* adding gemm+reduce

* adding gemm+reduce

* adding gemm+reduce

* adding gemm+reduce

* adding gemm+reduce

* format

* clean

* adding gemm+reduce

* adding profiler for gemm+reduce

* adding gemm+reduce profiler

* fix build

* clean up

* gemm+reduce

* fix build

* update DeviceGemm_Xdl_CShuffle; update enum to enum class

* clean up

* add test for gemm+reduce

* clean up

* refactor

* fix build

* fix build
This commit is contained in:
Chao Liu
2022-03-23 22:18:42 -05:00
committed by GitHub
parent f91579aab6
commit f95267f166
56 changed files with 4429 additions and 297 deletions

View File

@@ -84,7 +84,7 @@ static std::vector<T> getTypeValuesFromString(const char* cstr_values)
return (values);
}
typedef enum
enum struct appDataType_t
{
appHalf = 0,
appFloat = 1,
@@ -93,7 +93,7 @@ typedef enum
appInt8x4 = 4,
appBFloat16 = 5,
appDouble = 6,
} appDataType_t;
};
static void check_reduce_dims(const int rank, const std::vector<int>& reduceDims)
{
@@ -131,8 +131,8 @@ class AppArgs
std::vector<float> scales;
ReduceTensorOp_t reduceOp = ReduceTensorOp_t::ADD;
appDataType_t compTypeId = appFloat;
appDataType_t outTypeId = appFloat;
appDataType_t compTypeId = appDataType_t::appFloat;
appDataType_t outTypeId = appDataType_t::appFloat;
bool compType_assigned = false;
bool outType_assigned = false;
@@ -339,15 +339,16 @@ int profile_reduce(int argc, char* argv[])
if(args.use_half)
{
if(!args.compType_assigned)
args.compTypeId = appHalf;
args.compTypeId = appDataType_t::appHalf;
if(args.outType_assigned && (args.outTypeId != appHalf && args.outTypeId != appFloat))
args.outTypeId = appFloat;
if(args.outType_assigned &&
(args.outTypeId != appDataType_t::appHalf && args.outTypeId != appDataType_t::appFloat))
args.outTypeId = appDataType_t::appFloat;
if(!args.outType_assigned)
args.outTypeId = appHalf;
args.outTypeId = appDataType_t::appHalf;
if(args.compTypeId == appHalf)
if(args.compTypeId == appDataType_t::appHalf)
{
profile_reduce_impl<ck::half_t, ck::half_t, ck::half_t>(args.do_verification,
args.init_method,
@@ -362,7 +363,7 @@ int profile_reduce(int argc, char* argv[])
args.scales[0],
args.scales[1]);
}
else if(args.compTypeId == appFloat)
else if(args.compTypeId == appDataType_t::appFloat)
{
profile_reduce_impl<ck::half_t, float, ck::half_t>(args.do_verification,
args.init_method,
@@ -398,15 +399,16 @@ int profile_reduce(int argc, char* argv[])
else if(args.use_int8)
{
if(!args.compType_assigned)
args.compTypeId = appInt8;
args.compTypeId = appDataType_t::appInt8;
if(args.outType_assigned && (args.outTypeId != appInt8 && args.outTypeId != appInt32))
args.outTypeId = appInt32;
if(args.outType_assigned &&
(args.outTypeId != appDataType_t::appInt8 && args.outTypeId != appDataType_t::appInt32))
args.outTypeId = appDataType_t::appInt32;
if(!args.outType_assigned)
args.outTypeId = appInt8;
args.outTypeId = appDataType_t::appInt8;
if(args.compTypeId == appInt8)
if(args.compTypeId == appDataType_t::appInt8)
{
profile_reduce_impl<int8_t, int8_t, int8_t>(args.do_verification,
args.init_method,
@@ -421,7 +423,7 @@ int profile_reduce(int argc, char* argv[])
args.scales[0],
args.scales[1]);
}
else if(args.compTypeId == appInt32)
else if(args.compTypeId == appDataType_t::appInt32)
{
profile_reduce_impl<int8_t, int32_t, int8_t>(args.do_verification,
args.init_method,
@@ -441,11 +443,12 @@ int profile_reduce(int argc, char* argv[])
}
else if(args.use_bf16)
{
if(args.outType_assigned && (args.outTypeId != appBFloat16 && args.outTypeId != appFloat))
args.outTypeId = appFloat;
if(args.outType_assigned && (args.outTypeId != appDataType_t::appBFloat16 &&
args.outTypeId != appDataType_t::appFloat))
args.outTypeId = appDataType_t::appFloat;
if(!args.outType_assigned)
args.outTypeId = appBFloat16;
args.outTypeId = appDataType_t::appBFloat16;
profile_reduce_impl<ck::bhalf_t, float, ck::bhalf_t>(args.do_verification,
args.init_method,
@@ -462,7 +465,7 @@ int profile_reduce(int argc, char* argv[])
}
else
{
if(args.compTypeId == appFloat)
if(args.compTypeId == appDataType_t::appFloat)
{
profile_reduce_impl<float, float, float>(args.do_verification,
args.init_method,
@@ -477,7 +480,7 @@ int profile_reduce(int argc, char* argv[])
args.scales[0],
args.scales[1]);
}
else if(args.compTypeId == appDouble)
else if(args.compTypeId == appDataType_t::appDouble)
{
profile_reduce_impl<float, double, float>(args.do_verification,
args.init_method,