mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
Reduction for int8 and bfloat16 (#125)
* Use thread cluster descriptor and explicit M_K 2d descriptor to simplify Blockwise Reduction * Change by replacing ReduceDims by NumReduceDims as Device Reduce interface template parameter * Rename the folder name for the pool2d and reduce examples * Update to reduction test scripts * Add Readme for pool2d_fwd and reduce_blockwise examples * Add support for int8_t reduction (ADD/AVG, MIN/MAX/AMAX) * Tiny fix in reduce profiler and tiny update in reduce testing scripts * Tiny fix in testing script profile_reduce_no_index.sh * Tiny fix in testing script profile_reduce_no_index.sh * Add support for bfp16 reduction (using bhalf_t = ushort) * Tiny fix in amd_buffer_addressing.hpp * Tiny change in script/profile_reduce_with_index.sh * Use AccDataType for Beta value and use element_wise::PassThrough * Use type_convert for type conversion in host layer reduction * Renaming and refining in Reduction profiler/device layer/examples * Renaming and refining in Reduction profiler/device layer/examples * Renaming all NumReduceDims to NumReduceDim * Fix the leaked type_convert in ThreadwiseTensorSliceTransfer_v2 * Update to testing scripts to add bf16 support * added more static_assert * Remove buggy tunable configurations defined in device_reduce_instance_xxx.hpp * Add static_assert to give compile-time warning for incorrect thread slice-size/vector-size configurations * minor change * Refine and fix (in GetWorkspaceSizeInBytes of MultiBlockPartialReduce) to make int8 completely pass * Tiny renaming in gridwise_2d_reduction_multiblock_partial_reduce.hpp * Tiny fix in script/profile_reduce_no_index.sh * Refine in DeviceReduce layer with regard to using NumInvariantDim/NumReduceDim or InvariantDims/ReduceDims * Generic renaming in host reduction and DeviceReduce layer * Add support for 4-d all dimension reduction in the profiler and add_device_reduce_xxx instances * Use multi-thread and simplification for host Reduction implementation * Add ctest for reduction * Update to clarify 
the use of the data init method in produce_reduce/example_reduce/test_reduce/ * Update to the reduce CTest executables to enable default testing behavior when no command argument is given * Renaming Co-authored-by: Jianfeng yan <jfyan008@gmail.com>
This commit is contained in:
@@ -34,6 +34,8 @@ static struct option long_options[] = {{"inLengths", required_argument, nullptr,
|
||||
{"scales", required_argument, nullptr, 'S'},
|
||||
{"half", no_argument, nullptr, '?'},
|
||||
{"double", no_argument, nullptr, '?'},
|
||||
{"int8", no_argument, nullptr, '?'},
|
||||
{"bf16", no_argument, nullptr, '?'},
|
||||
{"dumpout", required_argument, nullptr, 'o'},
|
||||
{"verify", required_argument, nullptr, 'v'},
|
||||
{"log", required_argument, nullptr, 'l'},
|
||||
@@ -119,6 +121,8 @@ class AppArgs
|
||||
public:
|
||||
bool use_half = false;
|
||||
bool use_double = false;
|
||||
bool use_int8 = false;
|
||||
bool use_bf16 = false;
|
||||
|
||||
std::vector<size_t> inLengths;
|
||||
std::vector<size_t> outLengths;
|
||||
@@ -169,6 +173,8 @@ class AppArgs
|
||||
<< std::endl;
|
||||
std::cout << "--half, use fp16 for the input and output tensor data types" << std::endl;
|
||||
std::cout << "--double, use fp64 for the input and output tensor data types" << std::endl;
|
||||
std::cout << "--int8, use int8 for the input and output tensor data types" << std::endl;
|
||||
std::cout << "--bf16, use bfloat16 for the input and output tensor data types" << std::endl;
|
||||
std::cout << "--verify or -v, 1/0 to indicate whether to verify the reduction result by "
|
||||
"comparing with the host-based reduction"
|
||||
<< std::endl;
|
||||
@@ -267,6 +273,10 @@ class AppArgs
|
||||
use_half = true;
|
||||
else if(std::string(long_options[option_index].name) == "double")
|
||||
use_double = true;
|
||||
else if(std::string(long_options[option_index].name) == "int8")
|
||||
use_int8 = true;
|
||||
else if(std::string(long_options[option_index].name) == "bf16")
|
||||
use_bf16 = true;
|
||||
else if(std::string(long_options[option_index].name) == "help")
|
||||
{
|
||||
show_usage(argv[0]);
|
||||
@@ -385,6 +395,71 @@ int profile_reduce(int argc, char* argv[])
|
||||
args.scales[0],
|
||||
args.scales[1]);
|
||||
}
|
||||
else if(args.use_int8)
|
||||
{
|
||||
if(!args.compType_assigned)
|
||||
args.compTypeId = appInt8;
|
||||
|
||||
if(args.outType_assigned && (args.outTypeId != appInt8 && args.outTypeId != appInt32))
|
||||
args.outTypeId = appInt32;
|
||||
|
||||
if(!args.outType_assigned)
|
||||
args.outTypeId = appInt8;
|
||||
|
||||
if(args.compTypeId == appInt8)
|
||||
{
|
||||
profile_reduce_impl<int8_t, int8_t, int8_t>(args.do_verification,
|
||||
args.init_method,
|
||||
args.do_log,
|
||||
args.do_dumpout,
|
||||
args.nrepeat,
|
||||
args.inLengths,
|
||||
args.reduceDims,
|
||||
args.reduceOp,
|
||||
args.nanOpt,
|
||||
args.indicesOpt,
|
||||
args.scales[0],
|
||||
args.scales[1]);
|
||||
}
|
||||
else if(args.compTypeId == appInt32)
|
||||
{
|
||||
profile_reduce_impl<int8_t, int32_t, int8_t>(args.do_verification,
|
||||
args.init_method,
|
||||
args.do_log,
|
||||
args.do_dumpout,
|
||||
args.nrepeat,
|
||||
args.inLengths,
|
||||
args.reduceDims,
|
||||
args.reduceOp,
|
||||
args.nanOpt,
|
||||
args.indicesOpt,
|
||||
args.scales[0],
|
||||
args.scales[1]);
|
||||
}
|
||||
else
|
||||
throw std::runtime_error("Invalid compType assignment!");
|
||||
}
|
||||
else if(args.use_bf16)
|
||||
{
|
||||
if(args.outType_assigned && (args.outTypeId != appBFloat16 && args.outTypeId != appFloat))
|
||||
args.outTypeId = appFloat;
|
||||
|
||||
if(!args.outType_assigned)
|
||||
args.outTypeId = appBFloat16;
|
||||
|
||||
profile_reduce_impl<ck::bhalf_t, float, ck::bhalf_t>(args.do_verification,
|
||||
args.init_method,
|
||||
args.do_log,
|
||||
args.do_dumpout,
|
||||
args.nrepeat,
|
||||
args.inLengths,
|
||||
args.reduceDims,
|
||||
args.reduceOp,
|
||||
args.nanOpt,
|
||||
args.indicesOpt,
|
||||
args.scales[0],
|
||||
args.scales[1]);
|
||||
}
|
||||
else
|
||||
{
|
||||
if(args.compTypeId == appFloat)
|
||||
|
||||
Reference in New Issue
Block a user