mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Add int4 reduction examples (#372)
* Add int4 reduction examples
* Contain all using of int4_t inside the pre-compiling condition checking
[ROCm/composable_kernel commit: d520d0cfc1]
This commit is contained in:
@@ -225,6 +225,28 @@ int main(int argc, char* argv[])
|
||||
arg.scales[0],
|
||||
arg.scales[1]);
|
||||
}
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
else if(arg.data_type == 7)
|
||||
{
|
||||
pass = reduce_blockwise_test<int4_t, int32_t, ReduceTensorOp::AVG, false, false>(
|
||||
arg.do_verification,
|
||||
arg.init_method,
|
||||
arg.time_kernel,
|
||||
arg.inLengths,
|
||||
arg.reduceDims,
|
||||
arg.scales[0],
|
||||
arg.scales[1]);
|
||||
|
||||
pass = pass && reduce_blockwise_test<int4_t, int8_t, ReduceTensorOp::MAX, false, false>(
|
||||
arg.do_verification,
|
||||
arg.init_method,
|
||||
arg.time_kernel,
|
||||
arg.inLengths,
|
||||
arg.reduceDims,
|
||||
arg.scales[0],
|
||||
arg.scales[1]);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -251,6 +273,15 @@ int main(int argc, char* argv[])
|
||||
pass && reduce_blockwise_test<int8_t, int32_t, ReduceOpId, PropagateNan, OutputIndex>(
|
||||
true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f);
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
// for testing int4_t using AVG operation
|
||||
pass = pass && reduce_blockwise_test<int4_t, int32_t, ReduceTensorOp::AVG, false, false>(
|
||||
true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f);
|
||||
|
||||
// for testing int4_t using MAX operation
|
||||
pass = pass && reduce_blockwise_test<int4_t, int8_t, ReduceTensorOp::MAX, false, false>(
|
||||
true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f);
|
||||
#endif
|
||||
// for testing 3D input
|
||||
pass = pass && reduce_blockwise_test<float, float, ReduceOpId, PropagateNan, OutputIndex>(
|
||||
true, 2, true, {16, 64, 960}, {0, 1}, 1.0f, 0.0f);
|
||||
|
||||
@@ -58,28 +58,47 @@ int reduce_blockwise_impl(bool do_verification,
|
||||
std::is_same<InOutDataType, float>::value &&
|
||||
(op_support_indices && !std::is_same<AccDataType, float>::value);
|
||||
|
||||
// 1) If InOutDataType is int8_t, must use int8_t as AccDataType for indexable reduction
|
||||
// operations 2) If InOutDataType is int8_t, must use int32_t as AccDataType for non-indexable
|
||||
// reduction operations
|
||||
// 1) If InOutDataType is int8_t or int4_t, must use int8_t as AccDataType for indexable
|
||||
// reduction operations 2) If InOutDataType is int8_t or int4_t, must use int32_t as AccDataType
|
||||
// for non-indexable reduction operations
|
||||
constexpr bool invalid_reduce_4 =
|
||||
std::is_same<InOutDataType, int8_t>::value &&
|
||||
((!op_support_indices && !std::is_same<AccDataType, int32_t>::value) ||
|
||||
(op_support_indices && !std::is_same<AccDataType, int8_t>::value));
|
||||
|
||||
// 1) If InOutDataType is int8_t, the supported operation must be either indexable operations or
|
||||
// ADD/AVG
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
constexpr bool invalid_reduce_4_2 =
|
||||
std::is_same<InOutDataType, int4_t>::value &&
|
||||
((!op_support_indices && !std::is_same<AccDataType, int32_t>::value) ||
|
||||
(op_support_indices && !std::is_same<AccDataType, int8_t>::value));
|
||||
#endif
|
||||
|
||||
// 1) If InOutDataType is int8_t or int4_t, the supported operation must be either indexable
|
||||
// operations or ADD/AVG
|
||||
constexpr bool invalid_reduce_5 = std::is_same<InOutDataType, int8_t>::value &&
|
||||
(!op_support_indices && ReduceOpId != ReduceTensorOp::ADD &&
|
||||
ReduceOpId != ReduceTensorOp::AVG);
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
constexpr bool invalid_reduce_5_2 = std::is_same<InOutDataType, int4_t>::value &&
|
||||
(!op_support_indices && ReduceOpId != ReduceTensorOp::ADD &&
|
||||
ReduceOpId != ReduceTensorOp::AVG);
|
||||
#endif
|
||||
|
||||
// 1) If InOutDataType is bhalf_t, must use float as AccDataType for all reduction operations
|
||||
constexpr bool invalid_reduce_6 =
|
||||
std::is_same<InOutDataType, bhalf_t>::value && !std::is_same<AccDataType, float>::value;
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
constexpr bool invalid_reduce =
|
||||
(invalid_reduce_1 || invalid_reduce_2 || invalid_reduce_3 || invalid_reduce_4 ||
|
||||
invalid_reduce_5 || invalid_reduce_6 || invalid_reduce_4_2 || invalid_reduce_5_2);
|
||||
#else
|
||||
constexpr bool invalid_reduce = (invalid_reduce_1 || invalid_reduce_2 || invalid_reduce_3 ||
|
||||
invalid_reduce_4 || invalid_reduce_5 || invalid_reduce_6);
|
||||
#endif
|
||||
|
||||
if(invalid_reduce)
|
||||
if constexpr(invalid_reduce)
|
||||
{
|
||||
std::cerr << "The reduction setting is invalid, exiting!" << std::endl;
|
||||
return (-1);
|
||||
@@ -91,10 +110,17 @@ int reduce_blockwise_impl(bool do_verification,
|
||||
using AccElementwiseOperation =
|
||||
typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
using InOutDataTypeInDevice = typename std::
|
||||
conditional<std::is_same<InOutDataType, int4_t>::value, int8_t, InOutDataType>::type;
|
||||
#else
|
||||
using InOutDataTypeInDevice = InOutDataType;
|
||||
#endif
|
||||
|
||||
using DeviceReduceInstance =
|
||||
ck::tensor_operation::device::DeviceReduceMultiBlock<InOutDataType,
|
||||
ck::tensor_operation::device::DeviceReduceMultiBlock<InOutDataTypeInDevice,
|
||||
AccDataType,
|
||||
InOutDataType,
|
||||
InOutDataTypeInDevice,
|
||||
Rank,
|
||||
NumReduceDim,
|
||||
ReduceOperation,
|
||||
@@ -166,13 +192,35 @@ int reduce_blockwise_impl(bool do_verification,
|
||||
};
|
||||
|
||||
// these buffers are usually provided by the user application
|
||||
DeviceMem in_dev(sizeof(InOutDataType) * in.mDesc.GetElementSpaceSize());
|
||||
DeviceMem out_dev(sizeof(InOutDataType) * out.mDesc.GetElementSpaceSize());
|
||||
DeviceMem in_dev(sizeof(InOutDataTypeInDevice) * in.mDesc.GetElementSpaceSize());
|
||||
DeviceMem out_dev(sizeof(InOutDataTypeInDevice) * out.mDesc.GetElementSpaceSize());
|
||||
|
||||
in_dev.ToDevice(in.mData.data());
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
if(std::is_same<InOutDataType, int4_t>::value)
|
||||
{
|
||||
std::vector<InOutDataTypeInDevice> tmp_buf(in.mData.size());
|
||||
|
||||
std::copy_n(in.mData.data(), in.mData.size(), tmp_buf.data());
|
||||
in_dev.ToDevice(tmp_buf.data());
|
||||
}
|
||||
else
|
||||
#endif
|
||||
in_dev.ToDevice(in.mData.data());
|
||||
|
||||
if(beta != 0.0f)
|
||||
out_dev.ToDevice(out.mData.data());
|
||||
{
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
if(std::is_same<InOutDataType, int4_t>::value)
|
||||
{
|
||||
std::vector<InOutDataTypeInDevice> tmp_buf(in.mData.size());
|
||||
|
||||
std::copy_n(out.mData.data(), out.mData.size(), tmp_buf.data());
|
||||
out_dev.ToDevice(tmp_buf.data());
|
||||
}
|
||||
else
|
||||
#endif
|
||||
out_dev.ToDevice(out.mData.data());
|
||||
};
|
||||
|
||||
size_t indicesSizeInBytes = OutputIndex ? out.mDesc.GetElementSize() * sizeof(int32_t) : 0;
|
||||
|
||||
@@ -261,7 +309,19 @@ int reduce_blockwise_impl(bool do_verification,
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
out_dev.FromDevice(out.mData.data());
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
if(std::is_same<InOutDataType, int4_t>::value)
|
||||
{
|
||||
std::vector<InOutDataTypeInDevice> tmp_buf(out.mData.size());
|
||||
|
||||
out_dev.FromDevice(tmp_buf.data());
|
||||
|
||||
std::copy_n(tmp_buf.data(), out.mData.size(), out.mData.data());
|
||||
}
|
||||
else
|
||||
#endif
|
||||
out_dev.FromDevice(out.mData.data());
|
||||
|
||||
pass = pass && ck::utils::check_err(out.mData, out_ref.mData);
|
||||
|
||||
if(OutputIndex)
|
||||
|
||||
Reference in New Issue
Block a user