From 24eab229957d0c2b0fa22c68c2989271c7f0c9ad Mon Sep 17 00:00:00 2001 From: Qianfeng Date: Fri, 26 Aug 2022 05:58:48 +0800 Subject: [PATCH] Add int4 reduction examples (#372) * Add int4 reduction examples * Contain all using of int4_t inside the pre-compiling condition checking [ROCm/composable_kernel commit: d520d0cfc1ed1bda8a6a8e2caedcbe6232064217] --- example/12_reduce/reduce_blockwise.cpp | 31 ++++++++ example/12_reduce/reduce_blockwise_impl.hpp | 86 +++++++++++++++++---- 2 files changed, 104 insertions(+), 13 deletions(-) diff --git a/example/12_reduce/reduce_blockwise.cpp b/example/12_reduce/reduce_blockwise.cpp index 7cebbefb62..c1bcdbb826 100644 --- a/example/12_reduce/reduce_blockwise.cpp +++ b/example/12_reduce/reduce_blockwise.cpp @@ -225,6 +225,28 @@ int main(int argc, char* argv[]) arg.scales[0], arg.scales[1]); } +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + else if(arg.data_type == 7) + { + pass = reduce_blockwise_test( + arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inLengths, + arg.reduceDims, + arg.scales[0], + arg.scales[1]); + + pass = pass && reduce_blockwise_test( + arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inLengths, + arg.reduceDims, + arg.scales[0], + arg.scales[1]); + } +#endif } else { @@ -251,6 +273,15 @@ int main(int argc, char* argv[]) pass && reduce_blockwise_test( true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f); +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + // for testing int4_t using AVG operation + pass = pass && reduce_blockwise_test( + true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f); + + // for testing int4_t using MAX operation + pass = pass && reduce_blockwise_test( + true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f); +#endif // for testing 3D input pass = pass && reduce_blockwise_test( true, 2, true, {16, 64, 960}, {0, 1}, 1.0f, 0.0f); diff --git a/example/12_reduce/reduce_blockwise_impl.hpp b/example/12_reduce/reduce_blockwise_impl.hpp index c185773f63..ef5ec99481 100644 --- a/example/12_reduce/reduce_blockwise_impl.hpp +++ b/example/12_reduce/reduce_blockwise_impl.hpp @@ -58,28 +58,47 @@ int reduce_blockwise_impl(bool do_verification, std::is_same::value && (op_support_indices && !std::is_same::value); - // 1) If InOutDataType is int8_t, must use int8_t as AccDataType for indexable reduction - // operations 2) If InOutDataType is int8_t, must use int32_t as AccDataType for non-indexable - // reduction operations + // 1) If InOutDataType is int8_t or int4_t, must use int8_t as AccDataType for indexable + // reduction operations 2) If InOutDataType is int8_t or int4_t, must use int32_t as AccDataType + // for non-indexable reduction operations constexpr bool invalid_reduce_4 = std::is_same::value && ((!op_support_indices && !std::is_same::value) || (op_support_indices && !std::is_same::value)); - // 1) If InOutDataType is int8_t, the supported operation must be either indexable operations or - // ADD/AVG +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + constexpr bool invalid_reduce_4_2 = + std::is_same::value && + ((!op_support_indices && !std::is_same::value) || + (op_support_indices && !std::is_same::value)); +#endif + + // 1) If InOutDataType is int8_t or int4_t, the supported operation must be either indexable + // operations or ADD/AVG constexpr bool invalid_reduce_5 = std::is_same::value && (!op_support_indices && ReduceOpId != ReduceTensorOp::ADD && ReduceOpId != ReduceTensorOp::AVG); +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + constexpr bool invalid_reduce_5_2 = std::is_same::value && + (!op_support_indices && ReduceOpId != ReduceTensorOp::ADD && + ReduceOpId != ReduceTensorOp::AVG); +#endif + // 1) If InOutDataType is bhalf_t, must use float as AccDataType for all reduction operations constexpr bool invalid_reduce_6 = std::is_same::value && !std::is_same::value; +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + constexpr bool invalid_reduce = + (invalid_reduce_1 || invalid_reduce_2 || invalid_reduce_3 || invalid_reduce_4 || + invalid_reduce_5 || invalid_reduce_6 || invalid_reduce_4_2 || invalid_reduce_5_2); +#else constexpr bool invalid_reduce = (invalid_reduce_1 || invalid_reduce_2 || invalid_reduce_3 || invalid_reduce_4 || invalid_reduce_5 || invalid_reduce_6); +#endif - if(invalid_reduce) + if constexpr(invalid_reduce) { std::cerr << "The reduction setting is invalid, exiting!" << std::endl; return (-1); @@ -91,10 +110,17 @@ int reduce_blockwise_impl(bool do_verification, using AccElementwiseOperation = typename reduce_unary_operator::AccElementwiseOperation; +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + using InOutDataTypeInDevice = typename std:: + conditional::value, int8_t, InOutDataType>::type; +#else + using InOutDataTypeInDevice = InOutDataType; +#endif + using DeviceReduceInstance = - ck::tensor_operation::device::DeviceReduceMultiBlock::value) + { + std::vector tmp_buf(in.mData.size()); + + std::copy_n(in.mData.data(), in.mData.size(), tmp_buf.data()); + in_dev.ToDevice(tmp_buf.data()); + } + else +#endif + in_dev.ToDevice(in.mData.data()); if(beta != 0.0f) - out_dev.ToDevice(out.mData.data()); + { +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + if(std::is_same::value) + { + std::vector tmp_buf(in.mData.size()); + + std::copy_n(out.mData.data(), out.mData.size(), tmp_buf.data()); + out_dev.ToDevice(tmp_buf.data()); + } + else +#endif + out_dev.ToDevice(out.mData.data()); + }; size_t indicesSizeInBytes = OutputIndex ? out.mDesc.GetElementSize() * sizeof(int32_t) : 0; @@ -261,7 +309,19 @@ int reduce_blockwise_impl(bool do_verification, if(do_verification) { - out_dev.FromDevice(out.mData.data()); +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + if(std::is_same::value) + { + std::vector tmp_buf(out.mData.size()); + + out_dev.FromDevice(tmp_buf.data()); + + std::copy_n(tmp_buf.data(), out.mData.size(), out.mData.data()); + } + else +#endif + out_dev.FromDevice(out.mData.data()); + pass = pass && ck::utils::check_err(out.mData, out_ref.mData); if(OutputIndex)