mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 01:36:06 +00:00
Gemm+Reduce Fusion (#128)
* add gridwise gemm v4r1 * rename * adding gemm+reduce * adding gemm+reduce * adding gemm+reduce * adding gemm+reduce * use sfc in shuffling * remove hardcode * remove hardcode * refactor * fix build * adding gemm+reduce * adding gemm+reduce * adding gemm+reduce * adding gemm+reduce * adding gemm+reduce * format * clean * adding gemm+reduce * adding profiler for gemm+reduce * adding gemm+reduce profiler * fix build * clean up * gemm+reduce * fix build * update DeviceGemm_Xdl_CShuffle; update enum to enum class * clean up * add test for gemm+reduce * clean up * refactor * fix build * fix build
This commit is contained in:
@@ -84,7 +84,7 @@ static std::vector<T> getTypeValuesFromString(const char* cstr_values)
|
||||
return (values);
|
||||
}
|
||||
|
||||
typedef enum
|
||||
enum struct appDataType_t
|
||||
{
|
||||
appHalf = 0,
|
||||
appFloat = 1,
|
||||
@@ -93,7 +93,7 @@ typedef enum
|
||||
appInt8x4 = 4,
|
||||
appBFloat16 = 5,
|
||||
appDouble = 6,
|
||||
} appDataType_t;
|
||||
};
|
||||
|
||||
static void check_reduce_dims(const int rank, const std::vector<int>& reduceDims)
|
||||
{
|
||||
@@ -131,8 +131,8 @@ class AppArgs
|
||||
std::vector<float> scales;
|
||||
|
||||
ReduceTensorOp_t reduceOp = ReduceTensorOp_t::ADD;
|
||||
appDataType_t compTypeId = appFloat;
|
||||
appDataType_t outTypeId = appFloat;
|
||||
appDataType_t compTypeId = appDataType_t::appFloat;
|
||||
appDataType_t outTypeId = appDataType_t::appFloat;
|
||||
|
||||
bool compType_assigned = false;
|
||||
bool outType_assigned = false;
|
||||
@@ -339,15 +339,16 @@ int profile_reduce(int argc, char* argv[])
|
||||
if(args.use_half)
|
||||
{
|
||||
if(!args.compType_assigned)
|
||||
args.compTypeId = appHalf;
|
||||
args.compTypeId = appDataType_t::appHalf;
|
||||
|
||||
if(args.outType_assigned && (args.outTypeId != appHalf && args.outTypeId != appFloat))
|
||||
args.outTypeId = appFloat;
|
||||
if(args.outType_assigned &&
|
||||
(args.outTypeId != appDataType_t::appHalf && args.outTypeId != appDataType_t::appFloat))
|
||||
args.outTypeId = appDataType_t::appFloat;
|
||||
|
||||
if(!args.outType_assigned)
|
||||
args.outTypeId = appHalf;
|
||||
args.outTypeId = appDataType_t::appHalf;
|
||||
|
||||
if(args.compTypeId == appHalf)
|
||||
if(args.compTypeId == appDataType_t::appHalf)
|
||||
{
|
||||
profile_reduce_impl<ck::half_t, ck::half_t, ck::half_t>(args.do_verification,
|
||||
args.init_method,
|
||||
@@ -362,7 +363,7 @@ int profile_reduce(int argc, char* argv[])
|
||||
args.scales[0],
|
||||
args.scales[1]);
|
||||
}
|
||||
else if(args.compTypeId == appFloat)
|
||||
else if(args.compTypeId == appDataType_t::appFloat)
|
||||
{
|
||||
profile_reduce_impl<ck::half_t, float, ck::half_t>(args.do_verification,
|
||||
args.init_method,
|
||||
@@ -398,15 +399,16 @@ int profile_reduce(int argc, char* argv[])
|
||||
else if(args.use_int8)
|
||||
{
|
||||
if(!args.compType_assigned)
|
||||
args.compTypeId = appInt8;
|
||||
args.compTypeId = appDataType_t::appInt8;
|
||||
|
||||
if(args.outType_assigned && (args.outTypeId != appInt8 && args.outTypeId != appInt32))
|
||||
args.outTypeId = appInt32;
|
||||
if(args.outType_assigned &&
|
||||
(args.outTypeId != appDataType_t::appInt8 && args.outTypeId != appDataType_t::appInt32))
|
||||
args.outTypeId = appDataType_t::appInt32;
|
||||
|
||||
if(!args.outType_assigned)
|
||||
args.outTypeId = appInt8;
|
||||
args.outTypeId = appDataType_t::appInt8;
|
||||
|
||||
if(args.compTypeId == appInt8)
|
||||
if(args.compTypeId == appDataType_t::appInt8)
|
||||
{
|
||||
profile_reduce_impl<int8_t, int8_t, int8_t>(args.do_verification,
|
||||
args.init_method,
|
||||
@@ -421,7 +423,7 @@ int profile_reduce(int argc, char* argv[])
|
||||
args.scales[0],
|
||||
args.scales[1]);
|
||||
}
|
||||
else if(args.compTypeId == appInt32)
|
||||
else if(args.compTypeId == appDataType_t::appInt32)
|
||||
{
|
||||
profile_reduce_impl<int8_t, int32_t, int8_t>(args.do_verification,
|
||||
args.init_method,
|
||||
@@ -441,11 +443,12 @@ int profile_reduce(int argc, char* argv[])
|
||||
}
|
||||
else if(args.use_bf16)
|
||||
{
|
||||
if(args.outType_assigned && (args.outTypeId != appBFloat16 && args.outTypeId != appFloat))
|
||||
args.outTypeId = appFloat;
|
||||
if(args.outType_assigned && (args.outTypeId != appDataType_t::appBFloat16 &&
|
||||
args.outTypeId != appDataType_t::appFloat))
|
||||
args.outTypeId = appDataType_t::appFloat;
|
||||
|
||||
if(!args.outType_assigned)
|
||||
args.outTypeId = appBFloat16;
|
||||
args.outTypeId = appDataType_t::appBFloat16;
|
||||
|
||||
profile_reduce_impl<ck::bhalf_t, float, ck::bhalf_t>(args.do_verification,
|
||||
args.init_method,
|
||||
@@ -462,7 +465,7 @@ int profile_reduce(int argc, char* argv[])
|
||||
}
|
||||
else
|
||||
{
|
||||
if(args.compTypeId == appFloat)
|
||||
if(args.compTypeId == appDataType_t::appFloat)
|
||||
{
|
||||
profile_reduce_impl<float, float, float>(args.do_verification,
|
||||
args.init_method,
|
||||
@@ -477,7 +480,7 @@ int profile_reduce(int argc, char* argv[])
|
||||
args.scales[0],
|
||||
args.scales[1]);
|
||||
}
|
||||
else if(args.compTypeId == appDouble)
|
||||
else if(args.compTypeId == appDataType_t::appDouble)
|
||||
{
|
||||
profile_reduce_impl<float, double, float>(args.do_verification,
|
||||
args.init_method,
|
||||
|
||||
Reference in New Issue
Block a user