Mirror of https://github.com/ROCm/composable_kernel.git
[MIOpen Downstream] Fix Reduction Kernel (#34)
* Tiny fix in the use of data type template parameters in the blockwise and direct_threadwise kernels
* Fix the implementation of GetZeroVal() in both kernel and host
* Avoid converting to compType from dstDataType before writing the output value
* Add half_t support to NumericLimits and make GetZeroVal() of the binary operators constexpr
* Add the CONSTANT decorator for the descriptor read buffer
* Use get_thread_local_1d_id() for the thread-local id
* Rename GetZeroVal() to GetReductionZeroVal() in the kernels
* Remove constexpr from the initialized zeroVal and a tiny fix in reduction_operator.hpp
* Occasional tiny simplifications and updates in the kernel files
* Re-order tensor dimensions on the host, split the second_call kernel wrapper files, and simplify the reduce_all kernel wrappers
* Remove OpenCL tidy-check failures
* Update for better readability
* Remove unused code and unneeded template parameters in the kernel wrappers

Co-authored-by: Chao Liu <chao.liu2@amd.com>
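The numerically significant change here is the output path: the accumulated value is now scaled while still in compType, converted to dstDataType once, and only then blended with beta times the prior output. A minimal host-side sketch of that ordering (hypothetical finalize_output helper, not the kernel's actual StaticBuffer machinery):

#include <cstdio>

// Sketch: scale in the compute type, convert once, then blend beta in the
// destination type -- instead of converting the prior output up to CompType
// and back down again, as the old code did.
template <typename CompType, typename DstType>
DstType finalize_output(CompType accuValue, DstType priorDstValue, float alpha, float beta)
{
    if(alpha != 1.0f)
        accuValue *= static_cast<CompType>(alpha);

    DstType dstValue = static_cast<DstType>(accuValue); // single conversion

    if(beta != 0.0f)
        dstValue += priorDstValue * static_cast<DstType>(beta);

    return dstValue;
}

int main()
{
    // double accumulator, float output: 3.25 + 10 * 0.5 = 8.25
    std::printf("%f\n", finalize_output<double, float>(3.25, 10.0f, 1.0f, 0.5f));
    return 0;
}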
@@ -92,7 +92,7 @@ struct GridwiseReduction_xy_to_x_blockwise
         // LDS
         __shared__ compType p_in_block_buffer[BlockBufferSize];
 
-        auto zeroVal = opReduce::GetZeroVal();
+        const auto zeroVal = opReduce::GetReductionZeroVal();
 
         const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
             p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
@@ -180,6 +180,10 @@ struct GridwiseReduction_xy_to_x_blockwise
         if(!float_equal_one{}(alpha))
             accuValue_buf(I0) *= type_convert<compType>{}(alpha);
 
+        StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
+
+        dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
+
         if(!float_equal_zero{}(beta))
         {
             auto threadwise_dst_load =
@@ -200,11 +204,11 @@ struct GridwiseReduction_xy_to_x_blockwise
             threadwise_dst_load.Run(
                 dst1dDesc, dst_global_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf);
 
-            accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf[I0] * beta);
+            dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
         }
 
         auto threadwise_dst_store =
-            ThreadwiseTensorSliceTransfer_v1r3<compType,
+            ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
                                                dstDataType,
                                                decltype(ReducedDataDesc),
                                                dst1dDescType,
@@ -218,7 +222,7 @@ struct GridwiseReduction_xy_to_x_blockwise
                 make_multi_index(block_global_1d_id));
 
         threadwise_dst_store.Run(
-            ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_buf);
+            ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_buf);
     }
 };
@@ -239,7 +243,7 @@ struct GridwiseReduction_xy_to_x_blockwise
         __shared__ compType p_in_block_buffer[BlockBufferSize];
         __shared__ int block_indices_buffer[BlockBufferSize];
 
-        auto zeroVal = opReduce::GetZeroVal();
+        const auto zeroVal = opReduce::GetReductionZeroVal();
 
         const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
             p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
@@ -281,7 +285,7 @@ struct GridwiseReduction_xy_to_x_blockwise
                                             ThreadClusterLengths,
                                             Sequence<0, 1>,
                                             srcDataType,
-                                            dstDataType,
+                                            compType,
                                             src2dDescType,
                                             decltype(in_block_desc),
                                             Sequence<0, 1>,
@@ -345,6 +349,10 @@ struct GridwiseReduction_xy_to_x_blockwise
         if(!float_equal_one{}(alpha))
             accuValue_buf(I0) *= type_convert<compType>{}(alpha);
 
+        StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
+
+        dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
+
         if(!float_equal_zero{}(beta))
         {
             auto threadwise_dst_load =
@@ -368,11 +376,11 @@ struct GridwiseReduction_xy_to_x_blockwise
                 make_tuple(I0),
                 priorDstValue_buf);
 
-            accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf[I0] * beta);
+            dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
         }
 
         auto threadwise_dst_val_store =
-            ThreadwiseTensorSliceTransfer_v1r3<compType,
+            ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
                                                dstDataType,
                                                decltype(ReducedDataDesc),
                                                dst1dDescType,
@@ -400,7 +408,7 @@ struct GridwiseReduction_xy_to_x_blockwise
                 make_multi_index(block_global_1d_id));
 
         threadwise_dst_val_store.Run(
-            ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_val_buf);
+            ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
         threadwise_dst_idx_store.Run(
             ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
     }
@@ -423,7 +431,7 @@ struct GridwiseReduction_xy_to_x_blockwise
         __shared__ compType p_in_block_buffer[BlockBufferSize];
         __shared__ int block_indices_buffer[BlockBufferSize];
 
-        auto zeroVal = opReduce::GetZeroVal();
+        const auto zeroVal = opReduce::GetReductionZeroVal();
 
         const auto src_global_val_buf =
             make_dynamic_buffer<AddressSpaceEnum_t::Global>(ws_values_global,
@@ -547,6 +555,10 @@ struct GridwiseReduction_xy_to_x_blockwise
         if(!float_equal_one{}(alpha))
             accuValue_buf(I0) *= type_convert<compType>{}(alpha);
 
+        StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
+
+        dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
+
         if(!float_equal_zero{}(beta))
         {
             auto threadwise_dst_load =
@@ -570,11 +582,11 @@ struct GridwiseReduction_xy_to_x_blockwise
                 make_tuple(I0),
                 priorDstValue_buf);
 
-            accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf[I0] * beta);
+            dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
        }
 
         auto threadwise_dst_val_store =
-            ThreadwiseTensorSliceTransfer_v1r3<compType,
+            ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
                                                dstDataType,
                                                decltype(ReducedDataDesc),
                                                dst1dDescType,
@@ -602,7 +614,7 @@ struct GridwiseReduction_xy_to_x_blockwise
                 make_multi_index(block_global_1d_id));
 
         threadwise_dst_val_store.Run(
-            ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_val_buf);
+            ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
         threadwise_dst_idx_store.Run(
             ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
     }
@@ -82,7 +82,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
         (void)ws_indices_global;
         (void)indices_global;
 
-        const auto zeroVal = opReduce::GetZeroVal();
+        const auto zeroVal = opReduce::GetReductionZeroVal();
 
         const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
             p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
@@ -147,6 +147,10 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
         if(!float_equal_one{}(alpha))
             accuValue_buf(I0) *= type_convert<compType>{}(alpha);
 
+        StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
+
+        dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
+
         if(!float_equal_zero{}(beta))
         {
             auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<dstDataType,
@@ -166,11 +170,11 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
             threadwise_dst_load.Run(
                 dst1dDesc, dst_global_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf);
 
-            accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf[I0] * beta);
+            dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
         }
 
         auto threadwise_dst_store =
-            ThreadwiseTensorSliceTransfer_v1r3<compType,
+            ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
                                                dstDataType,
                                                decltype(ReducedDataDesc),
                                                dst1dDescType,
@@ -184,7 +188,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
                 make_multi_index(thread_global_1d_id));
 
         threadwise_dst_store.Run(
-            ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_buf);
+            ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_buf);
     };
 
 template <>
@@ -200,7 +204,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
     {
         (void)ws_indices_global;
 
-        const auto zeroVal = opReduce::GetZeroVal();
+        const auto zeroVal = opReduce::GetReductionZeroVal();
 
         const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
             p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
@@ -232,7 +236,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
         index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
 
         auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<srcDataType,
-                                                                    dstDataType,
+                                                                    compType,
                                                                     src2dDescType,
                                                                     decltype(ThreadBufferDesc),
                                                                     ThreadBufferLengths,
@@ -271,6 +275,10 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
         if(!float_equal_one{}(alpha))
             accuValue_buf(I0) *= type_convert<compType>{}(alpha);
 
+        StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
+
+        dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
+
         if(!float_equal_zero{}(beta))
         {
             auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<dstDataType,
@@ -290,11 +298,11 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
             threadwise_dst_load.Run(
                 dst1dDesc, dst_global_val_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf);
 
-            accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf[I0] * beta);
+            dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
         }
 
         auto threadwise_dst_val_store =
-            ThreadwiseTensorSliceTransfer_v1r3<compType,
+            ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
                                                dstDataType,
                                                decltype(ReducedDataDesc),
                                                dst1dDescType,
@@ -322,7 +330,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
                 make_multi_index(thread_global_1d_id));
 
         threadwise_dst_val_store.Run(
-            ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_val_buf);
+            ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
         threadwise_dst_idx_store.Run(
             ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
     };
@@ -340,7 +348,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
     {
         (void)origReduceLen;
 
-        const auto zeroVal = opReduce::GetZeroVal();
+        const auto zeroVal = opReduce::GetReductionZeroVal();
 
         const auto src_global_val_buf =
             make_dynamic_buffer<AddressSpaceEnum_t::Global>(ws_values_global,
@@ -377,7 +385,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
         index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
 
         auto threadwise_src_val_load = ThreadwiseTensorSliceTransfer_v2<srcDataType,
-                                                                        dstDataType,
+                                                                        compType,
                                                                         src2dDescType,
                                                                         decltype(ThreadBufferDesc),
                                                                         ThreadBufferLengths,
@@ -430,6 +438,10 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
         if(!float_equal_one{}(alpha))
             accuValue_buf(I0) *= type_convert<compType>{}(alpha);
 
+        StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
+
+        dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
+
         if(!float_equal_zero{}(beta))
         {
             auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<dstDataType,
@@ -449,11 +461,11 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
             threadwise_dst_load.Run(
                 dst1dDesc, dst_global_val_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf);
 
-            accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf[I0] * beta);
+            dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
         }
 
         auto threadwise_dst_val_store =
-            ThreadwiseTensorSliceTransfer_v1r3<compType,
+            ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
                                                dstDataType,
                                                decltype(ReducedDataDesc),
                                                dst1dDescType,
@@ -481,7 +493,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
                 make_multi_index(thread_global_1d_id));
 
         threadwise_dst_val_store.Run(
-            ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_val_buf);
+            ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
         threadwise_dst_idx_store.Run(
             ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
     };
@@ -82,7 +82,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
         (void)ws_indices_global;
         (void)indices_global;
 
-        auto zeroVal = opReduce::GetZeroVal();
+        const auto zeroVal = opReduce::GetReductionZeroVal();
 
         const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
             p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
@@ -156,6 +156,10 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
         if(!float_equal_one{}(alpha))
             accuValue_buf(I0) *= type_convert<compType>{}(alpha);
 
+        StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
+
+        dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
+
         if(!float_equal_zero{}(beta))
         {
             auto threadwise_dst_load =
@@ -176,11 +180,11 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
             threadwise_dst_load.Run(
                 dst1dDesc, dst_global_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf);
 
-            accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf(I0) * beta);
+            dstValue_buf(I0) += priorDstValue_buf(I0) * beta;
         }
 
         auto threadwise_dst_store =
-            ThreadwiseTensorSliceTransfer_v1r3<compType,
+            ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
                                                dstDataType,
                                                decltype(ReducedDataDesc),
                                                dst1dDescType,
@@ -194,7 +198,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
                 make_multi_index(warp_global_1d_id));
 
         threadwise_dst_store.Run(
-            ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_buf);
+            ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_buf);
     }
 };
@@ -211,7 +215,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
     {
         (void)ws_indices_global;
 
-        auto zeroVal = opReduce::GetZeroVal();
+        const auto zeroVal = opReduce::GetReductionZeroVal();
 
         const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
             p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
@@ -291,6 +295,10 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
         if(!float_equal_one{}(alpha))
             accuValue_buf(I0) *= type_convert<compType>{}(alpha);
 
+        StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
+
+        dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
+
         if(!float_equal_zero{}(beta))
         {
             auto threadwise_dst_load =
@@ -314,11 +322,11 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
                 make_tuple(I0),
                 priorDstValue_buf);
 
-            accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf[I0] * beta);
+            dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
         }
 
         auto threadwise_dst_val_store =
-            ThreadwiseTensorSliceTransfer_v1r3<compType,
+            ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
                                                dstDataType,
                                                decltype(ReducedDataDesc),
                                                dst1dDescType,
@@ -346,7 +354,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
                 make_multi_index(warp_global_1d_id));
 
         threadwise_dst_val_store.Run(
-            ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_val_buf);
+            ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
         threadwise_dst_idx_store.Run(
             ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
     }
@@ -365,7 +373,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
     {
         (void)origReduceLen;
 
-        auto zeroVal = opReduce::GetZeroVal();
+        const auto zeroVal = opReduce::GetReductionZeroVal();
 
         const auto src_global_val_buf =
             make_dynamic_buffer<AddressSpaceEnum_t::Global>(ws_values_global,
@@ -466,6 +474,10 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
         if(!float_equal_one{}(alpha))
             accuValue_buf(I0) *= type_convert<compType>{}(alpha);
 
+        StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
+
+        dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
+
         if(!float_equal_zero{}(beta))
         {
             auto threadwise_dst_load =
@@ -489,11 +501,11 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
                 make_tuple(I0),
                 priorDstValue_buf);
 
-            accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf[I0] * beta);
+            dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
         }
 
         auto threadwise_dst_val_store =
-            ThreadwiseTensorSliceTransfer_v1r3<compType,
+            ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
                                                dstDataType,
                                                decltype(ReducedDataDesc),
                                                dst1dDescType,
@@ -521,7 +533,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
                 make_multi_index(warp_global_1d_id));
 
         threadwise_dst_val_store.Run(
-            ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_val_buf);
+            ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
         threadwise_dst_idx_store.Run(
             ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
     }
@@ -86,7 +86,7 @@ struct GridwiseReduction_xy_to_x_multiblock
         (void)alpha; // unused
         (void)beta;  // unused
 
-        auto zeroVal = opReduce::GetZeroVal();
+        const auto zeroVal = opReduce::GetReductionZeroVal();
 
         // LDS
         __shared__ compType p_in_block_buffer[BlockBufferSize];
@@ -216,7 +216,7 @@ struct GridwiseReduction_xy_to_x_multiblock
         (void)alpha; // unused
         (void)beta;  // unused
 
-        auto zeroVal = opReduce::GetZeroVal();
+        const auto zeroVal = opReduce::GetReductionZeroVal();
 
         // LDS
         __shared__ compType p_in_block_values_buffer[BlockBufferSize];
@@ -56,7 +56,7 @@ struct BlockwiseReduction_2d_block_buffer
     Reduce(BufferType& block_buffer, index_t toReduceBlocks, compType& accuData)
     {
         const index_t thread_local_id = get_thread_local_1d_id();
-        compType lAccuData = opReduce::GetZeroVal();
+        compType lAccuData = opReduce::GetReductionZeroVal();
 
         index_t offset;
         for(index_t otherDimInd = 0; otherDimInd < toReduceBlocks; otherDimInd++)
@@ -115,7 +115,7 @@ struct BlockwiseReduction_2d_block_buffer
                                           int& accuIndex)
     {
         const index_t thread_local_id = get_thread_local_1d_id();
-        compType lAccuData = opReduce::GetZeroVal();
+        compType lAccuData = opReduce::GetReductionZeroVal();
         int lAccuIndex = 0;
 
         if constexpr(blockIsOneRow)
@@ -62,7 +62,7 @@ struct WarpReduce
     // This interface implementation uses HIP built-in device shuffling functions
     __device__ static void ReduceImpl1(const BufferType& thread_buffer, compType& accuData)
     {
-        compType lAccuData = opReduce::GetZeroVal();
+        compType lAccuData = opReduce::GetReductionZeroVal();
 
         static_for<0, ThreadBufferLen, 1>{}(
             [&](auto I) { binop::calculate(lAccuData, thread_buffer[I]); });
@@ -84,7 +84,7 @@ struct WarpReduce
     // since for fp16, built-in shuffling functions is not provided by HIP
     __device__ static void ReduceImpl2(const BufferType& thread_buffer, compType& accuData)
     {
-        compType lAccuData = opReduce::GetZeroVal();
+        compType lAccuData = opReduce::GetReductionZeroVal();
 
         static_for<0, ThreadBufferLen, 1>{}(
             [&](auto I) { binop::calculate(lAccuData, thread_buffer[I]); });
@@ -138,7 +138,7 @@ struct WarpReduce
                                        int& accuIndex,
                                        int indexStart)
     {
-        compType lAccuData = opReduce::GetZeroVal();
+        compType lAccuData = opReduce::GetReductionZeroVal();
         int lAccuIndex = 0;
         index_t thread_inwarp_id = get_thread_local_1d_id() % warpSize;
 
@@ -170,7 +170,7 @@ struct WarpReduce
                                        int& accuIndex,
                                        int indexStart)
     {
-        compType lAccuData = opReduce::GetZeroVal();
+        compType lAccuData = opReduce::GetReductionZeroVal();
         int lAccuIndex = 0;
         index_t thread_id = get_thread_local_1d_id();
         index_t warpId    = thread_id / warpSize;
@@ -278,7 +278,7 @@ struct WarpReduceWithIndicesInput
                                        compType& accuData,
                                        int& accuIndex)
     {
-        compType lAccuData = opReduce::GetZeroVal();
+        compType lAccuData = opReduce::GetReductionZeroVal();
         int lAccuIndex = 0;
 
         static_for<0, ThreadBufferLen, 1>{}([&](auto I) {
@@ -307,7 +307,7 @@ struct WarpReduceWithIndicesInput
                                        compType& accuData,
                                        int& accuIndex)
     {
-        compType lAccuData = opReduce::GetZeroVal();
+        compType lAccuData = opReduce::GetReductionZeroVal();
         int lAccuIndex = 0;
         index_t thread_id = get_thread_local_1d_id();
         index_t warpId    = thread_id / warpSize;
@@ -1008,20 +1008,27 @@ struct inner_product_with_conversion
 };
 
 template <typename T>
-struct NumericLimits;
+struct NumericLimits
+{
+    __host__ __device__ static constexpr T Min() { return std::numeric_limits<T>::min(); }
+
+    __host__ __device__ static constexpr T Max() { return std::numeric_limits<T>::max(); }
+
+    __host__ __device__ static constexpr T Lowest() { return std::numeric_limits<T>::lowest(); }
+};
 
 template <>
-struct NumericLimits<int32_t>
+struct NumericLimits<half_t>
 {
-    __host__ __device__ static constexpr int32_t Min()
-    {
-        return std::numeric_limits<int32_t>::min();
-    }
+    static constexpr unsigned short binary_min    = 0x0400;
+    static constexpr unsigned short binary_max    = 0x7BFF;
+    static constexpr unsigned short binary_lowest = 0xFBFF;
 
-    __host__ __device__ static constexpr int32_t Max()
-    {
-        return std::numeric_limits<int32_t>::max();
-    }
+    __host__ __device__ static constexpr half_t Min() { return as_type<half_t>(binary_min); }
+
+    __host__ __device__ static constexpr half_t Max() { return as_type<half_t>(binary_max); }
+
+    __host__ __device__ static constexpr half_t Lowest() { return as_type<half_t>(binary_lowest); }
 };
 
 } // namespace ck
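Assuming the standard IEEE binary16 layout (1 sign, 5 exponent, 10 mantissa bits), the three bit patterns above decode to the expected extremes: 0x0400 is the smallest positive normal value (2^-14), 0x7BFF is the largest finite value (65504), and 0xFBFF is its negation (-65504). A standalone host-side check of those constants:

#include <cassert>
#include <cmath>
#include <cstdint>

// Decode an IEEE binary16 bit pattern (normal numbers only) to sanity-check
// the constants used by NumericLimits<half_t> above.
float half_bits_to_float(uint16_t bits)
{
    int sign     = (bits >> 15) & 0x1;
    int exponent = (bits >> 10) & 0x1F; // 5-bit biased exponent
    int mantissa = bits & 0x3FF;        // 10-bit fraction
    float value  = std::ldexp(1.0f + mantissa / 1024.0f, exponent - 15);
    return sign ? -value : value;
}

int main()
{
    assert(half_bits_to_float(0x0400) == std::ldexp(1.0f, -14)); // Min: smallest normal
    assert(half_bits_to_float(0x7BFF) == 65504.0f);              // Max: largest finite
    assert(half_bits_to_float(0xFBFF) == -65504.0f);             // Lowest
    return 0;
}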
@@ -26,76 +26,25 @@
 #ifndef CK_REDUCTION_COMMON_HPP
 #define CK_REDUCTION_COMMON_HPP
 
-// this enumerate should be synchronized with include/miopen/reduce_common.hpp
+#include "reduction_enums.hpp"
 
 namespace ck {
 
 enum class ReductionMethod_t
 {
     DirectThreadWise = 1,
     DirectWarpWise   = 2,
     BlockWise        = 3,
     MultiBlock       = 4
 };
 
-enum class ReduceTensorOp_t
-{
-    ADD   = 0,
-    MUL   = 1,
-    MIN   = 2,
-    MAX   = 3,
-    AMAX  = 4,
-    AVG   = 5,
-    NORM1 = 6,
-    NORM2 = 7,
-    // MUL_NO_ZEROS = 8,
-};
-
-enum class NanPropagation_t
-{
-    NOT_PROPAGATE_NAN = 0,
-    PROPAGATE_NAN     = 1,
-};
-
-enum class ReduceTensorIndices_t
-{
-    NO_INDICES        = 0,
-    FLATTENED_INDICES = 1,
-};
-
-enum class IndicesType_t
-{
-    INDICES_32BIT = 0,
-    INDICES_64BIT = 1,
-    INDICES_16BIT = 2,
-    INDICES_8BIT  = 3,
-};
-
 struct float_equal_one
 {
-    template <class T>
-    __device__ static inline bool apply(T x)
-    {
-        return x <= type_convert<T>{}(1.0f) and x >= type_convert<T>{}(1.0f);
-    }
-
     template <class T>
     __device__ inline bool operator()(T x)
     {
-        return (float_equal_one::apply(x));
+        return x <= static_cast<T>(1.0f) and x >= static_cast<T>(1.0f);
     };
 };
 
 struct float_equal_zero
 {
-    template <class T>
-    __device__ static inline bool apply(T x)
-    {
-        return x <= type_convert<T>{}(0.0f) and x >= type_convert<T>{}(0.0f);
-    }
-
     template <class T>
     __device__ inline bool operator()(T x)
    {
-        return (float_equal_zero::apply(x));
+        return x <= static_cast<T>(0.0f) and x >= static_cast<T>(0.0f);
     };
 };
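The rewritten float_equal_one / float_equal_zero keep the double-inequality form rather than a plain operator==; presumably this tests exact equality against the sentinel values 1 and 0 without tripping compilers that warn on floating-point == (e.g. -Wfloat-equal). A standalone copy of the pattern:

#include <cassert>

// Same trick as float_equal_one / float_equal_zero: "x <= c and x >= c" is
// exact equality in effect, expressed without operator== on floats.
template <class T>
bool float_equal(T x, T c)
{
    return x <= c and x >= c;
}

int main()
{
    assert(float_equal(1.0f, 1.0f));  // alpha == 1 -> the kernels skip the scale
    assert(!float_equal(0.5f, 1.0f));
    assert(float_equal(0.0, 0.0));    // beta == 0 -> the kernels skip the blend
    return 0;
}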
composable_kernel/include/utility/reduction_enums.hpp (new file, 66 lines)
@@ -0,0 +1,66 @@
+/*******************************************************************************
+ *
+ * MIT License
+ *
+ * Copyright (c) 2020 Advanced Micro Devices, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ *
+ *******************************************************************************/
+#ifndef CK_REDUCTION_ENUMS_HPP
+#define CK_REDUCTION_ENUMS_HPP
+
+namespace ck {
+
+enum class ReduceTensorOp_t
+{
+    ADD   = 0,
+    MUL   = 1,
+    MIN   = 2,
+    MAX   = 3,
+    AMAX  = 4,
+    AVG   = 5,
+    NORM1 = 6,
+    NORM2 = 7,
+    // MUL_NO_ZEROS = 8,
+};
+
+enum class NanPropagation_t
+{
+    NOT_PROPAGATE_NAN = 0,
+    PROPAGATE_NAN     = 1,
+};
+
+enum class ReduceTensorIndices_t
+{
+    NO_INDICES        = 0,
+    FLATTENED_INDICES = 1,
+};
+
+enum class IndicesType_t
+{
+    INDICES_32BIT = 0,
+    INDICES_64BIT = 1,
+    INDICES_16BIT = 2,
+    INDICES_8BIT  = 3,
+};
+
+}; // end of namespace ck
+
+#endif
@@ -35,10 +35,12 @@ namespace reduce {
 // Every binary operator used in reduction is represented by a templated functor class. Each functor
 // class must provide at least
 // three members:
-// 1) GetZeroVal() -- the interface to return the "identity element" for the binary operator,
-//                    "identity element" is the unique
+// 1) GetReductionZeroVal() -- the interface to return the "identity element" for the binary
+//                             operator, "identity element" is the unique
 //    element in the algebraic space that doesn't affect the value of other elements
-//    when operated with any of them.
+//    when operated against them, and the concept is similar to zero vector in
+//    vector space
+//    (http://pages.cs.wisc.edu/~matthewb/pages/notes/pdf/linearalgebra/VectorSpaces.pdf).
 // 2) indexable -- boolean value indicating whether indices of the operated elements could be
 //                 recorded. Usually, Min/Max operator could
 //                 need to record the indices of elements. For operator like Add/Mul, no need to
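The contract described in the comment above can be illustrated with a minimal conforming functor. This sketch is standalone (no __device__ qualifiers, hypothetical SketchAdd name); the real functors live in namespace ck::reduce:

// A minimal functor satisfying the stated contract.
template <class T>
struct SketchAdd
{
    using dataType = T;

    // 1) The identity element: accumulating it changes nothing.
    static constexpr T GetReductionZeroVal() { return static_cast<T>(0); }

    // The binary operation itself: accumulate b into a.
    constexpr void operator()(T& a, T b) const { a = a + b; }

    // 2) Add is not an order statistic, so element indices are meaningless.
    static constexpr bool indexable = false;
};

int main()
{
    float acc = SketchAdd<float>::GetReductionZeroVal();
    SketchAdd<float>{}(acc, 2.0f);
    SketchAdd<float>{}(acc, 3.0f);
    return acc == 5.0f ? 0 : 1; // exits 0 on success
}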
@@ -58,7 +60,7 @@ struct Add
 {
     using dataType = T;
 
-    __device__ static T GetZeroVal() { return type_convert<T>{}(0.0f); };
+    __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(0.0f); };
 
     __device__ inline constexpr void operator()(T& a, T b) const { a = a + b; }
 
@@ -70,7 +72,7 @@ struct Mul
 {
     using dataType = T;
 
-    __device__ static T GetZeroVal() { return type_convert<T>{}(1.0f); };
+    __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(1.0f); };
 
     __device__ inline constexpr void operator()(T& a, T b) const { a = a * b; }
 
@@ -82,7 +84,7 @@ struct Max
 {
     using dataType = T;
 
-    __device__ static T GetZeroVal() { return std::numeric_limits<T>::min(); };
+    __device__ static constexpr T GetReductionZeroVal() { return NumericLimits<T>::Lowest(); };
 
     __device__ inline constexpr void operator()(T& a, T b) const
     {
@@ -107,7 +109,7 @@ struct Min
 {
     using dataType = T;
 
-    __device__ static T GetZeroVal() { return std::numeric_limits<T>::max(); };
+    __device__ static constexpr T GetReductionZeroVal() { return NumericLimits<T>::Max(); };
 
     __device__ inline constexpr void operator()(T& a, T b) const
     {
@@ -127,16 +129,29 @@ struct Min
     static constexpr bool indexable = true;
 };
 
-template <>
-__device__ half_t Max<half_t>::GetZeroVal()
+template <class T>
+struct AMax
 {
-    return type_convert<half_t>{}(std::numeric_limits<float>::min());
-};
+    using dataType = T;
 
-template <>
-__device__ half_t Min<half_t>::GetZeroVal()
-{
-    return type_convert<half_t>{}(std::numeric_limits<float>::max());
+    __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(0.0f); };
+
+    __device__ inline constexpr void operator()(T& a, T b) const
+    {
+        if(a < b)
+            a = b;
+    }
+
+    __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
+    {
+        if(a < b)
+        {
+            a = b;
+            changed = true;
+        }
+    }
+
+    static constexpr bool indexable = true;
 };
 
 // Unary operators are usually called element-wisely before the reduction is executed on the
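Note that AMax can use 0 as its identity even though it is a max-style compare: as the comment above says, a unary operator (here |x|) is applied element-wise before the reduction, so the binary op only ever sees non-negative values. A standalone sketch of that interplay (hypothetical AMaxSketch, with the |x| pre-op inlined):

#include <cassert>
#include <cmath>

template <class T>
struct AMaxSketch
{
    // Identity 0 is safe because every reduced value is non-negative.
    static constexpr T GetReductionZeroVal() { return static_cast<T>(0); }

    constexpr void operator()(T& a, T b) const
    {
        if(a < b)
            a = b;
    }
};

int main()
{
    const float data[] = {-7.0f, 3.0f, 5.0f};
    float acc = AMaxSketch<float>::GetReductionZeroVal();
    for(float x : data)
        AMaxSketch<float>{}(acc, std::fabs(x)); // element-wise |x| pre-op
    assert(acc == 7.0f); // AMAX = max(|x|)
    return 0;
}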
@@ -268,7 +283,7 @@ struct unary_sqrt<half_t>
 
 // The templated struct reduce_binary_operator maps the enum Ids of binary operators to their
 // respective functor classes.
-// The "GetZeroVal()" interface and boolean member "indexable" are also provided in
+// The "GetReductionZeroVal()" interface and boolean member "indexable" are also provided in
 // reduce_binary_operactor for
 // easier checking by the upper-layer codes in the kernels.
 
@@ -281,8 +296,6 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::ADD>
     using opType   = reduce::Add<T>;
     using dataType = T;
 
-    __device__ static T GetZeroVal() { return reduce::Add<T>::GetZeroVal(); };
-
     static constexpr bool indexable = reduce::Add<T>::indexable;
 };
 
@@ -292,8 +305,6 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::MUL>
     using opType   = reduce::Mul<T>;
     using dataType = T;
 
-    __device__ static T GetZeroVal() { return reduce::Mul<T>::GetZeroVal(); };
-
     static constexpr bool indexable = reduce::Mul<T>::indexable;
 };
 
@@ -303,8 +314,6 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::MIN>
     using opType   = reduce::Min<T>;
     using dataType = T;
 
-    __device__ static T GetZeroVal() { return reduce::Min<T>::GetZeroVal(); };
-
     static constexpr bool indexable = reduce::Min<T>::indexable;
 };
 
@@ -314,19 +323,15 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::MAX>
     using opType   = reduce::Max<T>;
     using dataType = T;
 
-    __device__ static T GetZeroVal() { return reduce::Max<T>::GetZeroVal(); };
-
     static constexpr bool indexable = reduce::Max<T>::indexable;
 };
 
 template <typename T>
 struct reduce_binary_operator<T, ReduceTensorOp_t::AMAX>
 {
-    using opType   = reduce::Max<T>;
+    using opType   = reduce::AMax<T>;
     using dataType = T;
 
-    __device__ static T GetZeroVal() { return reduce::Max<T>::GetZeroVal(); };
-
     static constexpr bool indexable = reduce::Max<T>::indexable;
 };
 
@@ -336,8 +341,6 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::AVG>
     using opType   = reduce::Add<T>;
     using dataType = T;
 
-    __device__ static T GetZeroVal() { return reduce::Add<T>::GetZeroVal(); };
-
     static constexpr bool indexable = reduce::Add<T>::indexable;
 };
 
@@ -347,8 +350,6 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::NORM1>
     using opType   = reduce::Add<T>;
     using dataType = T;
 
-    __device__ static T GetZeroVal() { return reduce::Add<T>::GetZeroVal(); };
-
     static constexpr bool indexable = reduce::Add<T>::indexable;
 };
 
@@ -358,8 +359,6 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::NORM2>
     using opType   = reduce::Add<T>;
     using dataType = T;
 
-    __device__ static T GetZeroVal() { return reduce::Add<T>::GetZeroVal(); };
-
     static constexpr bool indexable = reduce::Add<T>::indexable;
 };
@@ -43,9 +43,6 @@ using compType =
 constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
 
 constexpr index_t srcDims = CK_PARAM_IN_DIMS;
-constexpr index_t dstDims = CK_PARAM_OUT_DIMS;
-
-using toReduceDims = Sequence<CK_PARAM_TOREDUCE_DIMS>;
 
 constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
 constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
@@ -58,16 +55,6 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
 constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
 constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
 
-////////////////////////////////////////////////////////////////////////////////////////
-using specDims = typename sequence_merge<Sequence<>, toReduceDims>::type;
-
-static_assert(is_valid_sequence_map<specDims>::value && specDims::Size() == srcDims,
-              "Wrong invariant and/or toReduce dimensions!");
-
-// The number of invariant dimensions can be zero if all dimension are to be reduced
-static_assert(dstDims == 1,
-              "If all source dimensions are reduced, the dest should have only one dimension !!");
-
 constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
 constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
@@ -110,18 +97,6 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
                                                              int inStride3,
                                                              int inStride4,
                                                              int inStride5,
-                                                             int outLength0,
-                                                             int outLength1,
-                                                             int outLength2,
-                                                             int outLength3,
-                                                             int outLength4,
-                                                             int outLength5,
-                                                             int outStride0,
-                                                             int outStride1,
-                                                             int outStride2,
-                                                             int outStride3,
-                                                             int outStride4,
-                                                             int outStride5,
                                                              void* __restrict__ ws_global)
 {
     (void)GridSize;
@@ -132,18 +107,14 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
 
     const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5};
     const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5};
-    const int dstLengths[6] = {
-        outLength0, outLength1, outLength2, outLength3, outLength4, outLength5};
-    const int dstStrides[6] = {
-        outStride0, outStride1, outStride2, outStride3, outStride4, outStride5};
 
     const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number<srcDims>{});
     const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number<srcDims>{});
-    const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number<dstDims>{});
-    const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number<dstDims>{});
+    const auto tupleDstLengths = make_tuple(1);
+    const auto tupleDstStrides = make_tuple(1);
 
     const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
-    const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
+    auto dstDesc       = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
 
     const auto one_dim_srcDesc = transform_tensor_descriptor(
         srcDesc,
@@ -157,14 +128,8 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
         make_tuple(Sequence<0>{}),
         make_tuple(Sequence<0, 1>{}));
 
-    auto dst1dDesc = transform_tensor_descriptor(
-        dstDesc,
-        make_tuple(make_merge_transform(tupleDstLengths)),
-        make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
-        make_tuple(Sequence<0>{}));
-
-    const auto invariantLen = src2dDesc.GetLength(Number<0>{});
-    const auto toReduceLen  = src2dDesc.GetLength(Number<1>{});
+    constexpr int invariantLen = 1;
+    const auto toReduceLen     = src2dDesc.GetLength(Number<1>{});
 
     constexpr auto copySliceLen = BlockSize * GredAccessesPerThreadInBlock;
@@ -179,30 +144,28 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
                                           make_pad_transform(toReduceLen, 0, srcPad)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
-        if(hipThreadIdx_x == 0)
+        if(get_thread_local_1d_id() == 0)
             *static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
     }
     else
     {
-        if(hipThreadIdx_x == 0)
+        if(get_thread_local_1d_id() == 0)
            *static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
     }
 
-    if(hipThreadIdx_x == 0)
-        *static_cast<decltype(dst1dDesc)*>(p_dst1dDesc) = dst1dDesc;
+    if(get_thread_local_1d_id() == 0)
+        *static_cast<decltype(dstDesc)*>(p_dst1dDesc) = dstDesc;
 };
 
-template <index_t srcDims, index_t dstDims, typename invariantDims, typename toReduceDims>
+template <index_t srcDims>
 struct get_ref_desc_types
 {
     static constexpr auto ref_srcLengths = typename uniform_sequence_gen<srcDims, 8>::type{};
-    static constexpr auto ref_dstLengths = typename uniform_sequence_gen<dstDims, 1>::type{};
 
     // don't have to use accurate strides to get an expected referrence type
     static constexpr auto ref_srcDesc = make_naive_tensor_descriptor(
         make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths));
-    static constexpr auto ref_dstDesc = make_naive_tensor_descriptor(
-        make_tuple_from_seq(ref_dstLengths), make_tuple_from_seq(ref_dstLengths));
+    static constexpr auto ref_dstDesc = make_naive_tensor_descriptor(make_tuple(1), make_tuple(1));
 
     static constexpr auto ref_one_dim_srcDesc = transform_tensor_descriptor(
         ref_srcDesc,
@@ -217,12 +180,6 @@ struct get_ref_desc_types
         make_tuple(Sequence<0>{}),
         make_tuple(Sequence<0, 1>{}));
 
-    static constexpr auto ref_dst1dDesc = transform_tensor_descriptor(
-        ref_dstDesc,
-        make_tuple(make_merge_transform(make_tuple_from_seq(ref_dstLengths))),
-        make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
-        make_tuple(Sequence<0>{}));
-
     static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{});
     static constexpr auto ref_toReduceLen  = ref_src2dDesc.GetLength(Number<1>{});
@@ -235,25 +192,22 @@ struct get_ref_desc_types
                                                  make_tuple(Sequence<0>{}, Sequence<1>{})));
 
     using refType_dst1dDesc_padded =
-        decltype(transform_tensor_descriptor(ref_dst1dDesc,
+        decltype(transform_tensor_descriptor(ref_dstDesc,
                                              make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
                                              make_tuple(Sequence<0>{}),
                                              make_tuple(Sequence<0>{})));
 
     using refType_src2dDesc = decltype(ref_src2dDesc);
-    using refType_dst1dDesc = decltype(ref_dst1dDesc);
+    using refType_dst1dDesc = decltype(ref_dstDesc);
 };
 
-using refType_src2dDesc =
-    typename get_ref_desc_types<srcDims, dstDims, toReduceDims>::refType_src2dDesc;
-using refType_dst1dDesc =
-    typename get_ref_desc_types<srcDims, dstDims, toReduceDims>::refType_dst1dDesc;
+using refType_src2dDesc = typename get_ref_desc_types<srcDims>::refType_src2dDesc;
+using refType_dst1dDesc = typename get_ref_desc_types<srcDims>::refType_dst1dDesc;
 using refType_src2dDesc_padded_34 =
-    typename get_ref_desc_types<srcDims, dstDims, toReduceDims>::refType_src2dDesc_padded_34;
-using refType_dst1dDesc_padded =
-    typename get_ref_desc_types<srcDims, dstDims, toReduceDims>::refType_dst1dDesc_padded;
+    typename get_ref_desc_types<srcDims>::refType_src2dDesc_padded_34;
+using refType_dst1dDesc_padded = typename get_ref_desc_types<srcDims>::refType_dst1dDesc_padded;
 
-template <ReductionMethod_t impl, bool need_padding>
+template <bool need_padding>
 static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
 {
     if constexpr(need_padding)
@@ -277,15 +231,15 @@ extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen,
                                                      const void* __restrict__ p_src_global,
                                                      float beta,
                                                      void* __restrict__ p_dst_global,
-                                                     void* __restrict__ ws_global,
+                                                     const void CONSTANT* ws_global,
                                                      long ws_buf2_bytes_offset,
                                                      void* __restrict__ indices_global)
 {
     (void)BlkGroupSize;
     (void)ws_buf2_bytes_offset;
 
-    const void* p_src2dDesc = ws_global;
-    const void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
+    const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
+    const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
 
     const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
     const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
@@ -45,8 +45,11 @@ constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
 constexpr index_t srcDims = CK_PARAM_IN_DIMS;
 constexpr index_t dstDims = CK_PARAM_OUT_DIMS;
 
-using toReduceDims  = Sequence<CK_PARAM_TOREDUCE_DIMS>;
-using invariantDims = Sequence<CK_PARAM_INVARIANT_DIMS>;
+constexpr index_t num_toReduceDims  = CK_PARAM_NUM_TOREDUCE_DIMS;
+constexpr index_t num_invariantDims = srcDims - num_toReduceDims;
+
+using invariantDims = typename arithmetic_sequence_gen<0, num_invariantDims, 1>::type;
+using toReduceDims  = typename arithmetic_sequence_gen<num_invariantDims, srcDims, 1>::type;
 
 constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
 constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
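This relies on the host having already permuted the tensor so that the invariant dimensions occupy [0, num_invariantDims) and the to-reduce dimensions occupy [num_invariantDims, srcDims); the wrapper can then derive both index sequences from a single count. A sketch of what arithmetic_sequence_gen<first, last, 1> produces, using a plain constexpr helper in place of ck::Sequence:

#include <array>
#include <cstdio>

template <int First, int Last>
constexpr std::array<int, Last - First> arithmetic_sequence()
{
    std::array<int, Last - First> seq{};
    for(int i = First; i < Last; ++i)
        seq[i - First] = i; // First, First+1, ..., Last-1
    return seq;
}

int main()
{
    constexpr int srcDims           = 4;
    constexpr int num_toReduceDims  = 2;
    constexpr int num_invariantDims = srcDims - num_toReduceDims;

    constexpr auto invariantDims = arithmetic_sequence<0, num_invariantDims>();      // {0, 1}
    constexpr auto toReduceDims  = arithmetic_sequence<num_invariantDims, srcDims>(); // {2, 3}

    std::printf("invariant: %d %d  toReduce: %d %d\n",
                invariantDims[0], invariantDims[1], toReduceDims[0], toReduceDims[1]);
    return 0;
}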
@@ -59,15 +62,7 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
 constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
 constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
 
-////////////////////////////////////////////////////////////////////////////////////////
-using specDims = typename sequence_merge<invariantDims, toReduceDims>::type;
-
-static_assert(is_valid_sequence_map<specDims>::value && specDims::Size() == srcDims,
-              "Wrong invariant and/or toReduce dimensions!");
-
-// The number of invariant dimensions can be zero if all dimension are to be reduced
-static_assert(invariantDims::Size() > 0 || dstDims == 1,
-              "If all source dimensions are reduced, the dest should have only one dimension !!");
+static_assert(num_invariantDims > 0, "Not all dimensins are reduced for this kernel !!");
 
 constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
 constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
@@ -111,12 +106,6 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
                                                              int inStride3,
                                                              int inStride4,
                                                              int inStride5,
-                                                             int outLength0,
-                                                             int outLength1,
-                                                             int outLength2,
-                                                             int outLength3,
-                                                             int outLength4,
-                                                             int outLength5,
                                                              int outStride0,
                                                              int outStride1,
                                                              int outStride2,
@@ -133,14 +122,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
 
     const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5};
     const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5};
-    const int dstLengths[6] = {
-        outLength0, outLength1, outLength2, outLength3, outLength4, outLength5};
     const int dstStrides[6] = {
         outStride0, outStride1, outStride2, outStride3, outStride4, outStride5};
 
     const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number<srcDims>{});
     const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number<srcDims>{});
-    const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number<dstDims>{});
+    const auto tupleDstLengths = make_tuple_from_array(srcLengths, Number<dstDims>{});
     const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number<dstDims>{});
 
     const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
@@ -179,16 +166,16 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
                                           make_pad_transform(toReduceLen, 0, srcPad)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
-        if(hipThreadIdx_x == 0)
+        if(get_thread_local_1d_id() == 0)
             *static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
     }
     else
     {
-        if(hipThreadIdx_x == 0)
+        if(get_thread_local_1d_id() == 0)
             *static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
     }
 
-    if(hipThreadIdx_x == 0)
+    if(get_thread_local_1d_id() == 0)
         *static_cast<decltype(dst1dDesc)*>(p_dst1dDesc) = dst1dDesc;
 };
@@ -278,15 +265,15 @@ extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen,
                                                      const void* __restrict__ p_src_global,
                                                      float beta,
                                                      void* __restrict__ p_dst_global,
-                                                     void* __restrict__ ws_global,
+                                                     const void CONSTANT* ws_global,
                                                      long ws_buf2_bytes_offset,
                                                      void* __restrict__ indices_global)
 {
     (void)BlkGroupSize;
     (void)ws_buf2_bytes_offset;
 
-    const void* p_src2dDesc = ws_global;
-    const void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
+    const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
+    const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
 
     const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
     const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
@@ -43,10 +43,6 @@ using compType =
|
||||
constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
|
||||
|
||||
constexpr index_t srcDims = CK_PARAM_IN_DIMS;
|
||||
constexpr index_t dstDims = CK_PARAM_OUT_DIMS;
|
||||
|
||||
using toReduceDims = Sequence<CK_PARAM_TOREDUCE_DIMS>;
|
||||
using invariantDims = Sequence<CK_PARAM_INVARIANT_DIMS>; // this could be empty
|
||||
|
||||
constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
|
||||
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
|
||||
@@ -59,16 +55,6 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
|
||||
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
|
||||
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////
|
||||
using specDims = typename sequence_merge<Sequence<>, toReduceDims>::type;
|
||||
|
||||
static_assert(is_valid_sequence_map<specDims>::value && specDims::Size() == srcDims,
|
||||
"Wrong invariant and/or toReduce dimensions!");
|
||||
|
||||
// The number of invariant dimensions can be zero if all dimension are to be reduced
|
||||
static_assert(dstDims == 1,
|
||||
"If all source dimensions are reduced, the dest should have only one dimension !!");
|
||||
|
||||
constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
|
||||
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
|
||||
|
||||
@@ -111,18 +97,6 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
|
||||
int inStride3,
|
||||
int inStride4,
|
||||
int inStride5,
|
||||
int outLength0,
|
||||
int outLength1,
|
||||
int outLength2,
|
||||
int outLength3,
|
||||
int outLength4,
|
||||
int outLength5,
|
||||
int outStride0,
|
||||
int outStride1,
|
||||
int outStride2,
|
||||
int outStride3,
|
||||
int outStride4,
|
||||
int outStride5,
|
||||
void* __restrict__ ws_global)
|
||||
{
|
||||
(void)GridSize;
|
||||
@@ -132,18 +106,14 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
|
||||
|
||||
const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5};
|
||||
const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5};
|
||||
const int dstLengths[6] = {
|
||||
outLength0, outLength1, outLength2, outLength3, outLength4, outLength5};
|
||||
const int dstStrides[6] = {
|
||||
outStride0, outStride1, outStride2, outStride3, outStride4, outStride5};
|
||||
|
||||
const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number<srcDims>{});
|
||||
const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number<srcDims>{});
|
||||
const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number<dstDims>{});
|
||||
const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number<dstDims>{});
|
||||
const auto tupleDstLengths = make_tuple(1);
|
||||
const auto tupleDstStrides = make_tuple(1);
|
||||
|
||||
const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
|
||||
const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
|
||||
auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
|
||||
|
||||
const auto one_dim_srcDesc = transform_tensor_descriptor(
|
||||
srcDesc,
|
||||
@@ -157,14 +127,8 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 1>{}));
|
||||
|
||||
auto dst1dDesc = transform_tensor_descriptor(
|
||||
dstDesc,
|
||||
make_tuple(make_merge_transform(tupleDstLengths)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto invariantLen = src2dDesc.GetLength(Number<0>{});
|
||||
const auto toReduceLen = src2dDesc.GetLength(Number<1>{});
|
||||
constexpr int invariantLen = 1;
|
||||
const auto toReduceLen = src2dDesc.GetLength(Number<1>{});
|
||||
|
||||
constexpr auto copySliceLen = BlockSize * GredAccessesPerThreadInBlock;
|
||||
const index_t reduceSizePerBlock =
|
||||
@@ -181,30 +145,28 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
|
||||
make_pad_transform(toReduceLen, 0, srcPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
if(hipThreadIdx_x == 0)
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(hipThreadIdx_x == 0)
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
|
||||
}
|
||||
|
||||
if(hipThreadIdx_x == 0)
|
||||
*static_cast<decltype(dst1dDesc)*>(p_dst1dDesc) = dst1dDesc;
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(dstDesc)*>(p_dst1dDesc) = dstDesc;
|
||||
};
|
||||

template <index_t srcDims, index_t dstDims, typename toReduceDims>
template <index_t srcDims>
struct get_ref_desc_types
{
static constexpr auto ref_srcLengths = typename uniform_sequence_gen<srcDims, 8>::type{};
static constexpr auto ref_dstLengths = typename uniform_sequence_gen<dstDims, 1>::type{};

// don't have to use accurate strides to get an expected reference type
static constexpr auto ref_srcDesc = make_naive_tensor_descriptor(
make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths));
static constexpr auto ref_dstDesc = make_naive_tensor_descriptor(
make_tuple_from_seq(ref_dstLengths), make_tuple_from_seq(ref_dstLengths));
static constexpr auto ref_dstDesc = make_naive_tensor_descriptor(make_tuple(1), make_tuple(1));

static constexpr auto ref_one_dim_srcDesc = transform_tensor_descriptor(
ref_srcDesc,
@@ -219,12 +181,6 @@ struct get_ref_desc_types
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{}));

static constexpr auto ref_dst1dDesc = transform_tensor_descriptor(
ref_dstDesc,
make_tuple(make_merge_transform(make_tuple_from_seq(ref_dstLengths))),
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
make_tuple(Sequence<0>{}));

static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{});
static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{});

@@ -237,23 +193,20 @@ struct get_ref_desc_types
make_tuple(Sequence<0>{}, Sequence<1>{})));

using refType_dst1dDesc_padded =
decltype(transform_tensor_descriptor(ref_dst1dDesc,
decltype(transform_tensor_descriptor(ref_dstDesc,
make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{})));

using refType_src2dDesc = decltype(ref_src2dDesc);
using refType_dst1dDesc = decltype(ref_dst1dDesc);
using refType_dst1dDesc = decltype(ref_dstDesc);
};

using refType_src2dDesc =
typename get_ref_desc_types<srcDims, dstDims, toReduceDims>::refType_src2dDesc;
using refType_dst1dDesc =
typename get_ref_desc_types<srcDims, dstDims, toReduceDims>::refType_dst1dDesc;
using refType_src2dDesc = typename get_ref_desc_types<srcDims>::refType_src2dDesc;
using refType_dst1dDesc = typename get_ref_desc_types<srcDims>::refType_dst1dDesc;
using refType_src2dDesc_padded_34 =
typename get_ref_desc_types<srcDims, dstDims, toReduceDims>::refType_src2dDesc_padded_34;
using refType_dst1dDesc_padded =
typename get_ref_desc_types<srcDims, dstDims, toReduceDims>::refType_dst1dDesc_padded;
typename get_ref_desc_types<srcDims>::refType_src2dDesc_padded_34;
using refType_dst1dDesc_padded = typename get_ref_desc_types<srcDims>::refType_dst1dDesc_padded;

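The refType_* aliases exist because the main kernel must reinterpret raw workspace bytes as a descriptor whose exact type only the prepare kernel could name. Rebuilding a "reference" descriptor from dummy compile-time lengths reproduces that type, giving reinterpret_cast a target. A stand-alone sketch of the idea with a simplified placeholder type (not the actual CK descriptor classes):

#include <cstring>
// DummyDesc stands in for a CK tensor descriptor; only its *type* matters here.
template <int Len>
struct DummyDesc
{
    int length = Len;
};
using refType_desc = DummyDesc<8>; // dummy length, same type shape as the runtime object
void write_desc(void* ws) // what the *_prepare kernel does, conceptually
{
    refType_desc d{};
    std::memcpy(ws, &d, sizeof(d));
}
refType_desc read_desc(const void* ws) // what the main kernel does, conceptually
{
    return *reinterpret_cast<const refType_desc*>(ws);
}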
template <bool need_padding>
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
@@ -279,16 +232,16 @@ extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen,
const void* __restrict__ p_src_global,
float beta,
void* __restrict__ p_dst_global,
void* __restrict__ ws_global,
const void CONSTANT* ws_global,
long ws_buf2_bytes_offset,
void* __restrict__ indices_global)
{
(void)p_dst_global;
(void)indices_global;

const void* p_src2dDesc = ws_global;
const void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
void* ws_buf1_global = static_cast<char*>(ws_global) + 4096;
const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
void* ws_buf1_global = const_cast<char*>(static_cast<const char*>(p_src2dDesc) + 4096);

const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);

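Every wrapper in this commit slices one workspace buffer at the same fixed byte offsets. A sketch of the layout implied by the pointer arithmetic above (offsets are from the diff; p_ws is a hypothetical name for the generic-address-space pointer):

// ws_global layout shared by the prepare/main kernel pairs:
//   [0,    2048) : serialized src2dDesc, written by thread 0 of *_prepare
//   [2048, 4096) : serialized dst1dDesc
//   [4096, ...)  : ws_buf1_global, the partial-result workspace
const char* base  = static_cast<const char*>(p_ws);
const void* src2d = base;                           // offset 0
const void* dst1d = base + 2048;                    // offset 2048
void*       buf1  = const_cast<char*>(base) + 4096; // offset 4096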
@@ -45,8 +45,11 @@ constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
constexpr index_t srcDims = CK_PARAM_IN_DIMS;
constexpr index_t dstDims = CK_PARAM_OUT_DIMS;

using toReduceDims = Sequence<CK_PARAM_TOREDUCE_DIMS>;
using invariantDims = Sequence<CK_PARAM_INVARIANT_DIMS>;
constexpr index_t num_toReduceDims = CK_PARAM_NUM_TOREDUCE_DIMS;
constexpr index_t num_invariantDims = srcDims - num_toReduceDims;

using invariantDims = typename arithmetic_sequence_gen<0, num_invariantDims, 1>::type;
using toReduceDims = typename arithmetic_sequence_gen<num_invariantDims, srcDims, 1>::type;

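Because the host now reorders the tensor so that all invariant dimensions precede the reduced ones, both index sequences can be generated arithmetically rather than passed in as macros. A worked example of what the two generators produce for srcDims = 4 and num_toReduceDims = 2:

// num_invariantDims = 4 - 2 = 2
// arithmetic_sequence_gen<0, 2, 1>::type  ->  Sequence<0, 1>   (invariantDims)
// arithmetic_sequence_gen<2, 4, 1>::type  ->  Sequence<2, 3>   (toReduceDims)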
constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
@@ -59,15 +62,7 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);

////////////////////////////////////////////////////////////////////////////////////////
using specDims = typename sequence_merge<invariantDims, toReduceDims>::type;

static_assert(is_valid_sequence_map<specDims>::value && specDims::Size() == srcDims,
"Wrong invariant and/or toReduce dimensions!");

// The number of invariant dimensions can be zero if all dimensions are to be reduced
static_assert(invariantDims::Size() > 0 || dstDims == 1,
"If all source dimensions are reduced, the dest should have only one dimension !!");
static_assert(num_invariantDims > 0, "Not all dimensions are reduced for this kernel !!");

constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
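need_indices is true only when the reduction can report a winning position as well as a value. A hedged illustration of what such an indexable trait can look like; this is an assumption for exposition, not the actual reduce_binary_operator definition in CK (the enumerator names are also assumed):

// Illustration only: index reporting makes sense for picking ops (MIN/MAX/AMAX),
// not for accumulating ops (ADD, MUL, AVG, NORM1, NORM2).
template <typename T, ReduceTensorOp_t Op>
struct indexable_trait
{
    static constexpr bool value = false; // default: no meaningful index
};
template <typename T>
struct indexable_trait<T, ReduceTensorOp_t::MIN>
{
    static constexpr bool value = true; // position of the minimum is well defined
};
template <typename T>
struct indexable_trait<T, ReduceTensorOp_t::MAX>
{
    static constexpr bool value = true;
};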
@@ -111,12 +106,6 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
int inStride3,
int inStride4,
int inStride5,
int outLength0,
int outLength1,
int outLength2,
int outLength3,
int outLength4,
int outLength5,
int outStride0,
int outStride1,
int outStride2,
@@ -132,14 +121,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,

const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5};
const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5};
const int dstLengths[6] = {
outLength0, outLength1, outLength2, outLength3, outLength4, outLength5};
const int dstStrides[6] = {
outStride0, outStride1, outStride2, outStride3, outStride4, outStride5};

const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number<srcDims>{});
const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number<srcDims>{});
const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number<dstDims>{});
const auto tupleDstLengths = make_tuple_from_array(srcLengths, Number<dstDims>{});
const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number<dstDims>{});

const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
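Dropping the outLengthX/outStrideX arguments is safe here only because of the host-side reordering: once the invariant dimensions are moved to the front, the first dstDims entries of srcLengths are exactly the destination lengths, so the kernel can read them from the source array. A worked example, assuming a 4-d source reduced over its last two dimensions:

// reordered srcLengths = {N, C, H, W}, invariant dims {N, C} first, dstDims = 2
// make_tuple_from_array(srcLengths, Number<2>{})  ->  (N, C)
// i.e. the same values the removed outLength0/outLength1 parameters carried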
@@ -180,16 +167,16 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
make_pad_transform(toReduceLen, 0, srcPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if(hipThreadIdx_x == 0)
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
}
else
{
if(hipThreadIdx_x == 0)
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
}

if(hipThreadIdx_x == 0)
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dst1dDesc)*>(p_dst1dDesc) = dst1dDesc;
};

@@ -279,16 +266,16 @@ extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen,
const void* __restrict__ p_src_global,
float beta,
void* __restrict__ p_dst_global,
void* __restrict__ ws_global,
const void CONSTANT* ws_global,
long ws_buf2_bytes_offset,
void* __restrict__ indices_global)
{
(void)p_dst_global;
(void)indices_global;

const void* p_src2dDesc = ws_global;
const void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
void* ws_buf1_global = static_cast<char*>(ws_global) + 4096;
const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
void* ws_buf1_global = const_cast<char*>(static_cast<const char*>(p_src2dDesc) + 4096);

const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);

@@ -43,9 +43,6 @@ using compType =
constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable

constexpr index_t srcDims = CK_PARAM_IN_DIMS;
constexpr index_t dstDims = CK_PARAM_OUT_DIMS;

using toReduceDims = Sequence<CK_PARAM_TOREDUCE_DIMS>;

constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
@@ -58,16 +55,6 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);

////////////////////////////////////////////////////////////////////////////////////////
using specDims = typename sequence_merge<Sequence<>, toReduceDims>::type;

static_assert(is_valid_sequence_map<specDims>::value && specDims::Size() == srcDims,
"Wrong invariant and/or toReduce dimensions!");

// The number of invariant dimensions can be zero if all dimensions are to be reduced
static_assert(dstDims == 1,
"If all source dimensions are reduced, the dest should have only one dimension !!");

constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);

@@ -110,18 +97,6 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
int inStride3,
int inStride4,
int inStride5,
int outLength0,
int outLength1,
int outLength2,
int outLength3,
int outLength4,
int outLength5,
int outStride0,
int outStride1,
int outStride2,
int outStride3,
int outStride4,
int outStride5,
void* __restrict__ ws_global)
{
(void)BlkGroupSize;
@@ -131,18 +106,14 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,

const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5};
const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5};
const int dstLengths[6] = {
outLength0, outLength1, outLength2, outLength3, outLength4, outLength5};
const int dstStrides[6] = {
outStride0, outStride1, outStride2, outStride3, outStride4, outStride5};

const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number<srcDims>{});
const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number<srcDims>{});
const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number<dstDims>{});
const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number<dstDims>{});
const auto tupleDstLengths = make_tuple(1);
const auto tupleDstStrides = make_tuple(1);

const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);

const auto one_dim_srcDesc = transform_tensor_descriptor(
srcDesc,
@@ -156,14 +127,8 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{}));

auto dst1dDesc = transform_tensor_descriptor(
dstDesc,
make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
make_tuple(Sequence<0>{}));

const auto invariantLen = src2dDesc.GetLength(Number<0>{});
const auto toReduceLen = src2dDesc.GetLength(Number<1>{});
constexpr int invariantLen = 1;
const auto toReduceLen = src2dDesc.GetLength(Number<1>{});

constexpr auto copySliceLen = GredThreadBufferLength;

@@ -178,12 +143,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
make_pad_transform(toReduceLen, 0, srcPad2)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if(hipThreadIdx_x == 0)
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
}
else
{
if(hipThreadIdx_x == 0)
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
}

@@ -191,31 +156,29 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
{
const auto dstPad = GridSize * BlockSize - invariantLen;
auto dst1dDesc_2 =
transform_tensor_descriptor(dst1dDesc,
transform_tensor_descriptor(dstDesc,
make_tuple(make_pad_transform(invariantLen, 0, dstPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
if(hipThreadIdx_x == 0)
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dst1dDesc_2)*>(p_dst1dDesc) = dst1dDesc_2;
}
else
{
if(hipThreadIdx_x == 0)
*static_cast<decltype(dst1dDesc)*>(p_dst1dDesc) = dst1dDesc;
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dstDesc)*>(p_dst1dDesc) = dstDesc;
}
};

template <index_t srcDims, index_t dstDims, typename toReduceDims>
template <index_t srcDims>
struct get_ref_desc_types
{
static constexpr auto ref_srcLengths = typename uniform_sequence_gen<srcDims, 8>::type{};
static constexpr auto ref_dstLengths = typename uniform_sequence_gen<dstDims, 1>::type{};

// don't have to use accurate strides to get an expected reference type
static constexpr auto ref_srcDesc = make_naive_tensor_descriptor(
make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths));
static constexpr auto ref_dstDesc = make_naive_tensor_descriptor(
make_tuple_from_seq(ref_dstLengths), make_tuple_from_seq(ref_dstLengths));
static constexpr auto ref_dstDesc = make_naive_tensor_descriptor(make_tuple(1), make_tuple(1));

static constexpr auto ref_one_dim_srcDesc = transform_tensor_descriptor(
ref_srcDesc,
@@ -230,12 +193,6 @@ struct get_ref_desc_types
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{}));

static constexpr auto ref_dst1dDesc = transform_tensor_descriptor(
ref_dstDesc,
make_tuple(make_merge_transform(make_tuple_from_seq(ref_dstLengths))),
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
make_tuple(Sequence<0>{}));

static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{});
static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{});

@@ -248,23 +205,20 @@ struct get_ref_desc_types
make_tuple(Sequence<0>{}, Sequence<1>{})));

using refType_dst1dDesc_padded =
decltype(transform_tensor_descriptor(ref_dst1dDesc,
decltype(transform_tensor_descriptor(ref_dstDesc,
make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{})));

using refType_src2dDesc = decltype(ref_src2dDesc);
using refType_dst1dDesc = decltype(ref_dst1dDesc);
using refType_dst1dDesc = decltype(ref_dstDesc);
};

using refType_src2dDesc =
typename get_ref_desc_types<srcDims, dstDims, toReduceDims>::refType_src2dDesc;
using refType_dst1dDesc =
typename get_ref_desc_types<srcDims, dstDims, toReduceDims>::refType_dst1dDesc;
using refType_src2dDesc = typename get_ref_desc_types<srcDims>::refType_src2dDesc;
using refType_dst1dDesc = typename get_ref_desc_types<srcDims>::refType_dst1dDesc;
using refType_src2dDesc_padded_12 =
typename get_ref_desc_types<srcDims, dstDims, toReduceDims>::refType_src2dDesc_padded_12;
using refType_dst1dDesc_padded =
typename get_ref_desc_types<srcDims, dstDims, toReduceDims>::refType_dst1dDesc_padded;
typename get_ref_desc_types<srcDims>::refType_src2dDesc_padded_12;
using refType_dst1dDesc_padded = typename get_ref_desc_types<srcDims>::refType_dst1dDesc_padded;

template <bool need_padding>
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
@@ -290,15 +244,15 @@ extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen,
const void* __restrict__ p_src_global,
float beta,
void* __restrict__ p_dst_global,
void* __restrict__ ws_global,
const void CONSTANT* ws_global,
long ws_buf2_bytes_offset,
void* __restrict__ indices_global)
{
(void)BlkGroupSize;
(void)ws_buf2_bytes_offset;

const void* p_src2dDesc = ws_global;
const void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;

const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);

@@ -45,8 +45,11 @@ constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
constexpr index_t srcDims = CK_PARAM_IN_DIMS;
constexpr index_t dstDims = CK_PARAM_OUT_DIMS;

using toReduceDims = Sequence<CK_PARAM_TOREDUCE_DIMS>;
using invariantDims = Sequence<CK_PARAM_INVARIANT_DIMS>;
constexpr index_t num_toReduceDims = CK_PARAM_NUM_TOREDUCE_DIMS;
constexpr index_t num_invariantDims = srcDims - num_toReduceDims;

using invariantDims = typename arithmetic_sequence_gen<0, num_invariantDims, 1>::type;
using toReduceDims = typename arithmetic_sequence_gen<num_invariantDims, srcDims, 1>::type;

constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
@@ -59,15 +62,7 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);

////////////////////////////////////////////////////////////////////////////////////////
using specDims = typename sequence_merge<invariantDims, toReduceDims>::type;

static_assert(is_valid_sequence_map<specDims>::value && specDims::Size() == srcDims,
"Wrong invariant and/or toReduce dimensions!");

// The number of invariant dimensions can be zero if all dimensions are to be reduced
static_assert(invariantDims::Size() > 0 || dstDims == 1,
"If all source dimensions are reduced, the dest should have only one dimension !!");
static_assert(num_invariantDims > 0, "Not all dimensions are reduced for this kernel !!");

constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
@@ -111,12 +106,6 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
int inStride3,
int inStride4,
int inStride5,
int outLength0,
int outLength1,
int outLength2,
int outLength3,
int outLength4,
int outLength5,
int outStride0,
int outStride1,
int outStride2,
@@ -132,14 +121,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,

const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5};
const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5};
const int dstLengths[6] = {
outLength0, outLength1, outLength2, outLength3, outLength4, outLength5};
const int dstStrides[6] = {
outStride0, outStride1, outStride2, outStride3, outStride4, outStride5};

const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number<srcDims>{});
const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number<srcDims>{});
const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number<dstDims>{});
const auto tupleDstLengths = make_tuple_from_array(srcLengths, Number<dstDims>{});
const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number<dstDims>{});

const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
@@ -179,12 +166,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
make_pad_transform(toReduceLen, 0, srcPad2)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if(hipThreadIdx_x == 0)
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
}
else
{
if(hipThreadIdx_x == 0)
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
}

@@ -196,12 +183,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
make_tuple(make_pad_transform(invariantLen, 0, dstPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
if(hipThreadIdx_x == 0)
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dst1dDesc_2)*>(p_dst1dDesc) = dst1dDesc_2;
}
else
{
if(hipThreadIdx_x == 0)
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dst1dDesc)*>(p_dst1dDesc) = dst1dDesc;
}
};
@@ -291,15 +278,15 @@ extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen,
const void* __restrict__ p_src_global,
float beta,
void* __restrict__ p_dst_global,
void* __restrict__ ws_global,
const void CONSTANT* ws_global,
long ws_buf2_bytes_offset,
void* __restrict__ indices_global)
{
(void)BlkGroupSize;
(void)ws_buf2_bytes_offset;

const void* p_src2dDesc = ws_global;
const void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;

const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);

@@ -43,9 +43,6 @@ using compType =
constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable

constexpr index_t srcDims = CK_PARAM_IN_DIMS;
constexpr index_t dstDims = CK_PARAM_OUT_DIMS;

using toReduceDims = Sequence<CK_PARAM_TOREDUCE_DIMS>;

constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
@@ -58,16 +55,6 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);

////////////////////////////////////////////////////////////////////////////////////////
using specDims = typename sequence_merge<Sequence<>, toReduceDims>::type;

static_assert(is_valid_sequence_map<specDims>::value && specDims::Size() == srcDims,
"Wrong invariant and/or toReduce dimensions!");

// The number of invariant dimensions can be zero if all dimensions are to be reduced
static_assert(dstDims == 1,
"If all source dimensions are reduced, the dest should have only one dimension !!");

constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);

@@ -110,18 +97,6 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
int inStride3,
int inStride4,
int inStride5,
int outLength0,
int outLength1,
int outLength2,
int outLength3,
int outLength4,
int outLength5,
int outStride0,
int outStride1,
int outStride2,
int outStride3,
int outStride4,
int outStride5,
void* __restrict__ ws_global)
{
(void)BlkGroupSize;
@@ -131,18 +106,14 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,

const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5};
const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5};
const int dstLengths[6] = {
outLength0, outLength1, outLength2, outLength3, outLength4, outLength5};
const int dstStrides[6] = {
outStride0, outStride1, outStride2, outStride3, outStride4, outStride5};

const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number<srcDims>{});
const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number<srcDims>{});
const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number<dstDims>{});
const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number<dstDims>{});
const auto tupleDstLengths = make_tuple(1);
const auto tupleDstStrides = make_tuple(1);

const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);

const auto one_dim_srcDesc = transform_tensor_descriptor(
srcDesc,
@@ -156,14 +127,8 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{}));

auto dst1dDesc = transform_tensor_descriptor(
dstDesc,
make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
make_tuple(Sequence<0>{}));

const auto invariantLen = src2dDesc.GetLength(Number<0>{});
const auto toReduceLen = src2dDesc.GetLength(Number<1>{});
constexpr int invariantLen = 1;
const auto toReduceLen = src2dDesc.GetLength(Number<1>{});

constexpr auto copySliceLen = warpSize * GredAccessesPerThreadInWarp;

@@ -179,12 +144,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
make_pad_transform(toReduceLen, 0, srcPad2)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if(hipThreadIdx_x == 0)
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
}
else
{
if(hipThreadIdx_x == 0)
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
}

@@ -192,31 +157,29 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
{
const auto dstPad = GridSize * BlockSize / warpSize - invariantLen;
auto dst1dDesc_2 =
transform_tensor_descriptor(dst1dDesc,
transform_tensor_descriptor(dstDesc,
make_tuple(make_pad_transform(invariantLen, 0, dstPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
if(hipThreadIdx_x == 0)
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dst1dDesc_2)*>(p_dst1dDesc) = dst1dDesc_2;
}
else
{
if(hipThreadIdx_x == 0)
*static_cast<decltype(dst1dDesc)*>(p_dst1dDesc) = dst1dDesc;
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dstDesc)*>(p_dst1dDesc) = dstDesc;
}
};

template <index_t srcDims, index_t dstDims, typename toReduceDims>
template <index_t srcDims>
struct get_ref_desc_types
{
static constexpr auto ref_srcLengths = typename uniform_sequence_gen<srcDims, 8>::type{};
static constexpr auto ref_dstLengths = typename uniform_sequence_gen<dstDims, 1>::type{};

// don't have to use accurate strides to get an expected reference type
static constexpr auto ref_srcDesc = make_naive_tensor_descriptor(
make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths));
static constexpr auto ref_dstDesc = make_naive_tensor_descriptor(
make_tuple_from_seq(ref_dstLengths), make_tuple_from_seq(ref_dstLengths));
static constexpr auto ref_dstDesc = make_naive_tensor_descriptor(make_tuple(1), make_tuple(1));

static constexpr auto ref_one_dim_srcDesc = transform_tensor_descriptor(
ref_srcDesc,
@@ -231,12 +194,6 @@ struct get_ref_desc_types
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{}));

static constexpr auto ref_dst1dDesc = transform_tensor_descriptor(
ref_dstDesc,
make_tuple(make_merge_transform(make_tuple_from_seq(ref_dstLengths))),
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
make_tuple(Sequence<0>{}));

static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{});
static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{});

@@ -249,23 +206,19 @@ struct get_ref_desc_types
make_tuple(Sequence<0>{}, Sequence<1>{})));

using refType_dst1dDesc_padded =
decltype(transform_tensor_descriptor(ref_dst1dDesc,
decltype(transform_tensor_descriptor(ref_dstDesc,
make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{})));

using refType_src2dDesc = decltype(ref_src2dDesc);
using refType_dst1dDesc = decltype(ref_dst1dDesc);
using refType_dst1dDesc = decltype(ref_dstDesc);
};

using refType_src2dDesc =
typename get_ref_desc_types<srcDims, dstDims, toReduceDims>::refType_src2dDesc;
using refType_dst1dDesc =
typename get_ref_desc_types<srcDims, dstDims, toReduceDims>::refType_dst1dDesc;
using refType_src2dDesc_padded_12 =
typename get_ref_desc_types<srcDims, dstDims, toReduceDims>::refType_src2dDesc_padded_12;
using refType_dst1dDesc_padded =
typename get_ref_desc_types<srcDims, dstDims, toReduceDims>::refType_dst1dDesc_padded;
using refType_src2dDesc = typename get_ref_desc_types<srcDims>::refType_src2dDesc;
using refType_dst1dDesc = typename get_ref_desc_types<srcDims>::refType_dst1dDesc;
using refType_src2dDesc_padded_12 = typename get_ref_desc_types<srcDims>::refType_src2dDesc_padded_12;
using refType_dst1dDesc_padded = typename get_ref_desc_types<srcDims>::refType_dst1dDesc_padded;

template <bool need_padding>
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
@@ -291,15 +244,15 @@ extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen,
const void* __restrict__ p_src_global,
float beta,
void* __restrict__ p_dst_global,
void* __restrict__ ws_global,
const void CONSTANT* ws_global,
long ws_buf2_bytes_offset,
void* __restrict__ indices_global)
{
(void)BlkGroupSize;
(void)ws_buf2_bytes_offset;

const void* p_src2dDesc = ws_global;
const void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;

const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);

@@ -45,8 +45,11 @@ constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
constexpr index_t srcDims = CK_PARAM_IN_DIMS;
constexpr index_t dstDims = CK_PARAM_OUT_DIMS;

using toReduceDims = Sequence<CK_PARAM_TOREDUCE_DIMS>;
using invariantDims = Sequence<CK_PARAM_INVARIANT_DIMS>;
constexpr index_t num_toReduceDims = CK_PARAM_NUM_TOREDUCE_DIMS;
constexpr index_t num_invariantDims = srcDims - num_toReduceDims;

using invariantDims = typename arithmetic_sequence_gen<0, num_invariantDims, 1>::type;
using toReduceDims = typename arithmetic_sequence_gen<num_invariantDims, srcDims, 1>::type;

constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
@@ -59,15 +62,7 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);

////////////////////////////////////////////////////////////////////////////////////////
using specDims = typename sequence_merge<invariantDims, toReduceDims>::type;

static_assert(is_valid_sequence_map<specDims>::value && specDims::Size() == srcDims,
"Wrong invariant and/or toReduce dimensions!");

// The number of invariant dimensions can be zero if all dimensions are to be reduced
static_assert(invariantDims::Size() > 0 || dstDims == 1,
"If all source dimensions are reduced, the dest should have only one dimension !!");
static_assert(num_invariantDims > 0, "Not all dimensions are reduced for this kernel !!");

constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
@@ -111,12 +106,6 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
int inStride3,
int inStride4,
int inStride5,
int outLength0,
int outLength1,
int outLength2,
int outLength3,
int outLength4,
int outLength5,
int outStride0,
int outStride1,
int outStride2,
@@ -132,14 +121,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,

const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5};
const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5};
const int dstLengths[6] = {
outLength0, outLength1, outLength2, outLength3, outLength4, outLength5};
const int dstStrides[6] = {
outStride0, outStride1, outStride2, outStride3, outStride4, outStride5};

const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number<srcDims>{});
const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number<srcDims>{});
const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number<dstDims>{});
const auto tupleDstLengths = make_tuple_from_array(srcLengths, Number<dstDims>{});
const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number<dstDims>{});

const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
@@ -179,12 +166,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
make_pad_transform(toReduceLen, 0, srcPad2)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if(hipThreadIdx_x == 0)
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
}
else
{
if(hipThreadIdx_x == 0)
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
}

@@ -196,12 +183,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
make_tuple(make_pad_transform(invariantLen, 0, dstPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
if(hipThreadIdx_x == 0)
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dst1dDesc_2)*>(p_dst1dDesc) = dst1dDesc_2;
}
else
{
if(hipThreadIdx_x == 0)
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dst1dDesc)*>(p_dst1dDesc) = dst1dDesc;
}
};
@@ -292,15 +279,15 @@ extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen,
const void* __restrict__ p_src_global,
float beta,
void* __restrict__ p_dst_global,
void* __restrict__ ws_global,
const void CONSTANT* ws_global,
long ws_buf2_bytes_offset,
void* __restrict__ indices_global)
{
(void)BlkGroupSize;
(void)ws_buf2_bytes_offset;

const void* p_src2dDesc = ws_global;
const void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;

const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);

@@ -0,0 +1,205 @@
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2021 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#include "config.hpp"
#include "number.hpp"
#include "sequence.hpp"
#include "tensor_descriptor_helper.hpp"
#include "data_type_enum_helper.hpp"
#include "reduction_common.hpp"
#include "gridwise_generic_2d_reduction_blockwise.hpp"

using namespace ck;

using srcDataType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_SRC_DATATYPE)>::type;
using dstDataType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_DST_DATATYPE)>::type;
using compType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_REDUCE_COMPTYPE)>::type;

constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable

constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
? NanPropagation_t::NOT_PROPAGATE_NAN
: NanPropagation_t::PROPAGATE_NAN;
constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
? ReduceTensorIndices_t::NO_INDICES
: ReduceTensorIndices_t::FLATTENED_INDICES;

constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);

constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);

constexpr index_t GredAccessesPerThreadInBlock = CK_PARAM_ACCESSES_PER_THREAD_INBLOCK; // tunable

extern "C" __global__ void
gridwise_generic_reduce_2_prepare(int GridSize, int BlkGroupSize, void* __restrict__ ws_global)
{
(void)GridSize;

void* p_src2dDesc = ws_global;
void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;

const auto tupleDstLengths = make_tuple(1);
const auto tupleDstStrides = make_tuple(1);

auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);

const index_t invariantLen = dstDesc.GetLength(Number<0>{});
const index_t toReduceLen = BlkGroupSize;

auto src2dDesc = make_naive_tensor_descriptor_packed(make_tuple(invariantLen, toReduceLen));

constexpr auto copySliceLen = BlockSize * GredAccessesPerThreadInBlock;

if constexpr(src2d_need_padding)
{
const auto srcPad =
((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen;

auto src2dDesc_2 =
transform_tensor_descriptor(src2dDesc,
make_tuple(make_pass_through_transform(invariantLen),
make_pad_transform(toReduceLen, 0, srcPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
}
else
{
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
}

if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dstDesc)*>(p_dst1dDesc) = dstDesc;
};
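In this second-stage wrapper the shape is fully determined: stage one reduced every dimension, so there is a single output element and the only axis left to reduce is the number of blocks that produced partial results. For example, if stage one ran with BlkGroupSize = 64, stage two sees a packed 1 x 64 matrix:

// invariantLen = dstDesc.GetLength(Number<0>{})  == 1   (one output value)
// toReduceLen  = BlkGroupSize                    == 64  (one partial per block)
auto src2dDesc = make_naive_tensor_descriptor_packed(make_tuple(1, 64)); // example values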

struct get_ref_desc_types
{
static constexpr auto ref_tupleDstLengths = make_tuple(8);
static constexpr auto ref_dstDesc =
make_naive_tensor_descriptor(ref_tupleDstLengths, ref_tupleDstLengths);

static constexpr index_t ref_invariantLen = ref_dstDesc.GetLength(Number<0>{});
static constexpr index_t ref_toReduceLen = 8;

static constexpr auto ref_src2dDesc =
make_naive_tensor_descriptor_packed(make_tuple(ref_invariantLen, ref_toReduceLen));

using refType_src2dDesc = decltype(ref_src2dDesc);
using refType_dst1dDesc = decltype(ref_dstDesc);

// used by the BlockWise and MultiBlock method
using refType_src2dDesc_padded_34 = decltype(
transform_tensor_descriptor(ref_src2dDesc,
make_tuple(make_pass_through_transform(ref_invariantLen),
make_pad_transform(ref_toReduceLen, 0, 2)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})));

using refType_dst1dDesc_padded =
decltype(transform_tensor_descriptor(ref_dstDesc,
make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{})));
};

using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc;
using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc;
using refType_src2dDesc_padded_34 = typename get_ref_desc_types::refType_src2dDesc_padded_34;
using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded;

template <bool need_padding>
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
{
if constexpr(need_padding)
return (*reinterpret_cast<const refType_src2dDesc_padded_34*>(p_src2dDesc));
else
return (*reinterpret_cast<const refType_src2dDesc*>(p_src2dDesc));
};

template <bool need_padding>
static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc)
{
if constexpr(need_padding)
return (*reinterpret_cast<const refType_dst1dDesc_padded*>(p_dst1dDesc));
else
return (*reinterpret_cast<const refType_dst1dDesc*>(p_dst1dDesc));
};
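Note that the two return statements in each helper above have different types. That compiles because if constexpr discards the untaken branch before return-type deduction, so auto deduces whichever descriptor type survives; this is the standard C++17 pattern the wrappers rely on. A minimal sketch of the same mechanism:

template <bool padded>
__device__ auto pick_descriptor(const void* p)
{
    if constexpr(padded)
        return *reinterpret_cast<const refType_dst1dDesc_padded*>(p); // padded type deduced
    else
        return *reinterpret_cast<const refType_dst1dDesc*>(p); // unpadded type deduced
}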

extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen,
float alpha,
const void* __restrict__ p_src_global,
float beta,
void* __restrict__ p_dst_global,
const void CONSTANT* ws_global,
long ws_buf2_bytes_offset,
void* __restrict__ indices_global)
{
(void)p_src_global;

const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
void* ws_buf1_global = const_cast<char*>(static_cast<const char*>(p_src2dDesc) + 4096);

const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);

using gridwise_2d_reduce = GridwiseReduction_xy_to_x_blockwise<BlockSize,
srcDataType,
dstDataType,
compType,
decltype(src2dDesc),
decltype(dst1dDesc),
op,
nanPropaOpt,
reduceIndicesOpt,
false,
true,
GredAccessesPerThreadInBlock>;

void* const ws_buf2_global =
ws_buf2_bytes_offset > 0
? static_cast<void*>(static_cast<char*>(ws_buf1_global) + ws_buf2_bytes_offset)
: nullptr;

constexpr int RunId = need_indices ? 3 : 1;
gridwise_2d_reduce::template Run<RunId>(
src2dDesc,
dst1dDesc,
origReduceLen,
alpha,
static_cast<const srcDataType* const __restrict__>(ws_buf1_global),
beta,
static_cast<dstDataType* const __restrict__>(p_dst_global),
static_cast<const int* const __restrict__>(ws_buf2_global),
static_cast<int* const __restrict__>(indices_global));
};
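Two details of this call are easy to miss. RunId statically selects the Run<> specialization, and ws_buf2_global is non-null only when stage one actually appended an index buffer, which the host signals with a positive byte offset. A hedged sketch of the dispatch idea (illustrative, not the actual Run in gridwise_generic_2d_reduction_blockwise.hpp):

// Run<1> is assumed to be the values-only path; Run<3> is assumed to also
// consume stage-one indices and produce indices_global.
template <int RunId>
__device__ void run_impl()
{
    if constexpr(RunId == 3)
    {
        // load stage-one indices from ws_buf2_global, write final indices_global
    }
    else
    {
        // values-only path; the index buffers are ignored
    }
}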
@@ -42,12 +42,8 @@ using compType =

constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable

constexpr index_t srcDims = CK_PARAM_IN_DIMS;
constexpr index_t dstDims = CK_PARAM_OUT_DIMS;

using toReduceDims = Sequence<CK_PARAM_TOREDUCE_DIMS>;
using invariantDims = Sequence<CK_PARAM_INVARIANT_DIMS>; // this could be empty

constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
? NanPropagation_t::NOT_PROPAGATE_NAN
@@ -59,16 +55,6 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);

////////////////////////////////////////////////////////////////////////////////////////
using specDims = typename sequence_merge<invariantDims, toReduceDims>::type;

static_assert(is_valid_sequence_map<specDims>::value && specDims::Size() == srcDims,
"Wrong invariant and/or toReduce dimensions!");

// The number of invariant dimensions can be zero if all dimensions are to be reduced
static_assert(invariantDims::Size() > 0 || dstDims == 1,
"If all source dimensions are reduced, the dest should have only one dimension !!");

constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);

@@ -152,20 +138,20 @@ extern "C" __global__ void gridwise_generic_reduce_2_prepare(int GridSize,
make_pad_transform(toReduceLen, 0, srcPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if(hipThreadIdx_x == 0)
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
}
else
{
if(hipThreadIdx_x == 0)
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
}

if(hipThreadIdx_x == 0)
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dst1dDesc)*>(p_dst1dDesc) = dst1dDesc;
};

template <index_t srcDims, index_t dstDims, typename invariantDims, typename toReduceDims>
template <index_t dstDims>
struct get_ref_desc_types
{
static constexpr auto ref_tupleDstLengths =
@@ -203,16 +189,11 @@ struct get_ref_desc_types
make_tuple(Sequence<0>{})));
};

using refType_src2dDesc =
typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::refType_src2dDesc;
using refType_dst1dDesc =
typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::refType_dst1dDesc;
using refType_src2dDesc = typename get_ref_desc_types<dstDims>::refType_src2dDesc;
using refType_dst1dDesc = typename get_ref_desc_types<dstDims>::refType_dst1dDesc;
using refType_src2dDesc_padded_34 =
typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::
refType_src2dDesc_padded_34;
using refType_dst1dDesc_padded =
typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::
refType_dst1dDesc_padded;
typename get_ref_desc_types<dstDims>::refType_src2dDesc_padded_34;
using refType_dst1dDesc_padded = typename get_ref_desc_types<dstDims>::refType_dst1dDesc_padded;

template <bool need_padding>
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
@@ -237,15 +218,15 @@ extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen,
const void* __restrict__ p_src_global,
float beta,
void* __restrict__ p_dst_global,
void* __restrict__ ws_global,
const void CONSTANT* ws_global,
long ws_buf2_bytes_offset,
void* __restrict__ indices_global)
{
(void)p_src_global;

const void* p_src2dDesc = ws_global;
const void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
void* ws_buf1_global = static_cast<char*>(ws_global) + 4096;
const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
void* ws_buf1_global = const_cast<char*>(static_cast<const char*>(p_src2dDesc) + 4096);

const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
@@ -0,0 +1,222 @@
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2021 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#include "config.hpp"
#include "number.hpp"
#include "sequence.hpp"
#include "tensor_descriptor_helper.hpp"
#include "data_type_enum_helper.hpp"
#include "reduction_common.hpp"
#include "gridwise_generic_2d_reduction_direct_threadwise.hpp"

using namespace ck;

using srcDataType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_SRC_DATATYPE)>::type;
using dstDataType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_DST_DATATYPE)>::type;
using compType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_REDUCE_COMPTYPE)>::type;

constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable

using toReduceDims = Sequence<CK_PARAM_TOREDUCE_DIMS>;
using invariantDims = Sequence<CK_PARAM_INVARIANT_DIMS>; // this could be empty

constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
? NanPropagation_t::NOT_PROPAGATE_NAN
: NanPropagation_t::PROPAGATE_NAN;
constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
? ReduceTensorIndices_t::NO_INDICES
: ReduceTensorIndices_t::FLATTENED_INDICES;

constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);

constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);

constexpr index_t GredThreadBufferLength = CK_PARAM_THREAD_BUFFER_LENGTH; // tunable

extern "C" __global__ void
gridwise_generic_reduce_2_prepare(int GridSize, int BlkGroupSize, void* __restrict__ ws_global)
{
(void)BlkGroupSize;

void* p_src2dDesc = ws_global;
void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;

const auto tupleDstLengths = make_tuple(1);
const auto tupleDstStrides = make_tuple(1);

auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);

const index_t invariantLen = dstDesc.GetLength(Number<0>{});
const index_t toReduceLen = BlkGroupSize;

auto src2dDesc = make_naive_tensor_descriptor_packed(make_tuple(invariantLen, toReduceLen));

constexpr auto copySliceLen = GredThreadBufferLength;

if constexpr(src2d_need_padding)
{
const auto srcPad1 = GridSize * BlockSize - invariantLen;
const auto srcPad2 =
((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen;
auto src2dDesc_2 =
transform_tensor_descriptor(src2dDesc,
make_tuple(make_pad_transform(invariantLen, 0, srcPad1),
make_pad_transform(toReduceLen, 0, srcPad2)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
}
else
{
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
}

if constexpr(dst1d_need_padding)
{
const auto dstPad = GridSize * BlockSize - invariantLen;
auto dst1dDesc_2 =
transform_tensor_descriptor(dstDesc,
make_tuple(make_pad_transform(invariantLen, 0, dstPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dst1dDesc_2)*>(p_dst1dDesc) = dst1dDesc_2;
}
else
{
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dstDesc)*>(p_dst1dDesc) = dstDesc;
}
};
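Unlike the blockwise second stage, the direct-threadwise prepare pads the invariant axis as well (srcPad1): each thread owns one output row, so the row count is rounded up to the full launch size GridSize * BlockSize. Worked numbers for this reduce-all case, where invariantLen is 1 (launch values are hypothetical):

// GridSize = 1, BlockSize = 256, BlkGroupSize = 64, GredThreadBufferLength = 8
// srcPad1 = 1 * 256 - 1 = 255                  // pad the single row up to the thread count
// srcPad2 = ((64 + 8 - 1) / 8) * 8 - 64 = 0    // 64 is already a multiple of 8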

struct get_ref_desc_types
{
static constexpr auto ref_tupleDstLengths = make_tuple(8);
static constexpr auto ref_dstDesc =
make_naive_tensor_descriptor(ref_tupleDstLengths, ref_tupleDstLengths);

static constexpr index_t ref_invariantLen = ref_dstDesc.GetLength(Number<0>{});
static constexpr index_t ref_toReduceLen = 8;

static constexpr auto ref_src2dDesc =
make_naive_tensor_descriptor_packed(make_tuple(ref_invariantLen, ref_toReduceLen));

using refType_src2dDesc = decltype(ref_src2dDesc);
using refType_dst1dDesc = decltype(ref_dstDesc);

// used by the DirectThreadWise and DirectWarpWise method
using refType_src2dDesc_padded_12 =
decltype(transform_tensor_descriptor(ref_src2dDesc,
make_tuple(make_pad_transform(ref_invariantLen, 0, 2),
make_pad_transform(ref_toReduceLen, 0, 2)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})));

using refType_dst1dDesc_padded =
decltype(transform_tensor_descriptor(ref_dstDesc,
make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{})));
};

using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc;
using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc;
using refType_src2dDesc_padded_12 = typename get_ref_desc_types::refType_src2dDesc_padded_12;
using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded;

template <bool need_padding>
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
{
if constexpr(need_padding)
return (*reinterpret_cast<const refType_src2dDesc_padded_12*>(p_src2dDesc));
else
return (*reinterpret_cast<const refType_src2dDesc*>(p_src2dDesc));
};

template <bool need_padding>
static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc)
{
if constexpr(need_padding)
return (*reinterpret_cast<const refType_dst1dDesc_padded*>(p_dst1dDesc));
else
return (*reinterpret_cast<const refType_dst1dDesc*>(p_dst1dDesc));
};
|
||||
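// The two helpers above are the read side of a small type-erasure scheme: the
// *_prepare kernel writes a concrete descriptor object into the workspace, and
// this file rebuilds a structurally identical type (refType_*) from fixed
// reference lengths so the stored bytes can be reinterpreted. A minimal sketch
// of the same pattern with a hypothetical Desc type (illustrative only):
//
//     struct Desc { int len; };
//     // prepare side:  *static_cast<Desc*>(ws) = Desc{42};
//     // consumer side: const Desc& d = *reinterpret_cast<const Desc*>(ws);
//
// This is only well-defined because both sides instantiate exactly the same
// descriptor type for a given need_padding value; the reference lengths (8,
// with pad 2) pin down the type, not the runtime lengths stored in the object.
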
extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen,
|
||||
float alpha,
|
||||
const void* __restrict__ p_src_global,
|
||||
float beta,
|
||||
void* __restrict__ p_dst_global,
|
||||
const void CONSTANT* ws_global,
|
||||
long ws_buf2_bytes_offset,
|
||||
void* __restrict__ indices_global)
|
||||
{
|
||||
(void)p_src_global;
|
||||
|
||||
const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
|
||||
const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
|
||||
void* ws_buf1_global = const_cast<char*>(static_cast<const char*>(p_src2dDesc) + 4096);
|
||||
|
||||
const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
|
||||
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
|
||||
|
||||
using gridwise_2d_reduce = GridwiseReduction_xy_to_x_direct_threadwise<BlockSize,
|
||||
srcDataType,
|
||||
dstDataType,
|
||||
compType,
|
||||
decltype(src2dDesc),
|
||||
decltype(dst1dDesc),
|
||||
op,
|
||||
nanPropaOpt,
|
||||
reduceIndicesOpt,
|
||||
false,
|
||||
true,
|
||||
GredThreadBufferLength>;
|
||||
|
||||
void* const ws_buf2_global =
|
||||
ws_buf2_bytes_offset > 0
|
||||
? static_cast<void*>(static_cast<char*>(ws_buf1_global) + ws_buf2_bytes_offset)
|
||||
: nullptr;
|
||||
|
||||
constexpr int RunId = need_indices ? 3 : 1;
|
||||
gridwise_2d_reduce::template Run<RunId>(
|
||||
src2dDesc,
|
||||
dst1dDesc,
|
||||
origReduceLen,
|
||||
alpha,
|
||||
static_cast<const srcDataType* const __restrict__>(ws_buf1_global),
|
||||
beta,
|
||||
static_cast<dstDataType* const __restrict__>(p_dst_global),
|
||||
static_cast<const int* const __restrict__>(ws_buf2_global),
|
||||
static_cast<int* const __restrict__>(indices_global));
|
||||
};
|
||||
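// Byte layout of ws_global assumed by the two kernels above (the 2048/4096
// offsets are the fixed constants used in this file):
//
//     [0,    2048)  slot for src2dDesc, written by gridwise_generic_reduce_2_prepare
//     [2048, 4096)  slot for dst1dDesc, written by gridwise_generic_reduce_2_prepare
//     [4096, ... )  ws_buf1_global: partial reduction values from the first call;
//                   when ws_buf2_bytes_offset > 0, the matching workspace indices
//                   live at ws_buf1_global + ws_buf2_bytes_offset
//
// ws_global is CONSTANT-decorated, so the descriptors are fetched through the
// constant address space; cast_pointer_to_generic_address_space() turns the
// pointer back into a generic one before the byte offsets are applied. Run<3>
// re-reduces value/index pairs when need_indices holds, while Run<1> reduces
// plain values.
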
@@ -42,12 +42,8 @@ using compType =

constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable

constexpr index_t srcDims = CK_PARAM_IN_DIMS;
constexpr index_t dstDims = CK_PARAM_OUT_DIMS;

using toReduceDims = Sequence<CK_PARAM_TOREDUCE_DIMS>;
using invariantDims = Sequence<CK_PARAM_INVARIANT_DIMS>; // this could be empty

constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
                                             ? NanPropagation_t::NOT_PROPAGATE_NAN
@@ -59,16 +55,6 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);

////////////////////////////////////////////////////////////////////////////////////////
using specDims = typename sequence_merge<invariantDims, toReduceDims>::type;

static_assert(is_valid_sequence_map<specDims>::value && specDims::Size() == srcDims,
              "Wrong invariant and/or toReduce dimensions!");

// The number of invariant dimensions can be zero if all dimension are to be reduced
static_assert(invariantDims::Size() > 0 || dstDims == 1,
              "If all source dimensions are reduced, the dest should have only one dimension !!");

constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);

@@ -152,12 +138,12 @@ extern "C" __global__ void gridwise_generic_reduce_2_prepare(int GridSize,
                                        make_pad_transform(toReduceLen, 0, srcPad2)),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));
        if(hipThreadIdx_x == 0)
        if(get_thread_local_1d_id() == 0)
            *static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
    }
    else
    {
        if(hipThreadIdx_x == 0)
        if(get_thread_local_1d_id() == 0)
            *static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
    }

@@ -169,17 +155,17 @@ extern "C" __global__ void gridwise_generic_reduce_2_prepare(int GridSize,
                                        make_tuple(make_pad_transform(invariantLen, 0, dstPad)),
                                        make_tuple(Sequence<0>{}),
                                        make_tuple(Sequence<0>{}));
        if(hipThreadIdx_x == 0)
        if(get_thread_local_1d_id() == 0)
            *static_cast<decltype(dst1dDesc_2)*>(p_dst1dDesc) = dst1dDesc_2;
    }
    else
    {
        if(hipThreadIdx_x == 0)
        if(get_thread_local_1d_id() == 0)
            *static_cast<decltype(dst1dDesc)*>(p_dst1dDesc) = dst1dDesc;
    }
};

template <index_t srcDims, index_t dstDims, typename invariantDims, typename toReduceDims>
template <index_t dstDims>
struct get_ref_desc_types
{
    static constexpr auto ref_tupleDstLengths =
@@ -217,16 +203,11 @@ struct get_ref_desc_types
                                         make_tuple(Sequence<0>{})));
};

using refType_src2dDesc =
    typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::refType_src2dDesc;
using refType_dst1dDesc =
    typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::refType_dst1dDesc;
using refType_src2dDesc = typename get_ref_desc_types<dstDims>::refType_src2dDesc;
using refType_dst1dDesc = typename get_ref_desc_types<dstDims>::refType_dst1dDesc;
using refType_src2dDesc_padded_12 =
    typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::
        refType_src2dDesc_padded_12;
using refType_dst1dDesc_padded =
    typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::
        refType_dst1dDesc_padded;
    typename get_ref_desc_types<dstDims>::refType_src2dDesc_padded_12;
using refType_dst1dDesc_padded = typename get_ref_desc_types<dstDims>::refType_dst1dDesc_padded;

template <bool need_padding>
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
@@ -251,15 +232,15 @@ extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen,
                                                     const void* __restrict__ p_src_global,
                                                     float beta,
                                                     void* __restrict__ p_dst_global,
                                                     void* __restrict__ ws_global,
                                                     const void CONSTANT* ws_global,
                                                     long ws_buf2_bytes_offset,
                                                     void* __restrict__ indices_global)
{
    (void)p_src_global;

    const void* p_src2dDesc = ws_global;
    const void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
    void* ws_buf1_global = static_cast<char*>(ws_global) + 4096;
    const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
    const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
    void* ws_buf1_global = const_cast<char*>(static_cast<const char*>(p_src2dDesc) + 4096);

    const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
    const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
@@ -0,0 +1,221 @@
/*******************************************************************************
 *
 * MIT License
 *
 * Copyright (c) 2021 Advanced Micro Devices, Inc.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 *
 *******************************************************************************/
#include "config.hpp"
#include "number.hpp"
#include "sequence.hpp"
#include "tensor_descriptor_helper.hpp"
#include "data_type_enum_helper.hpp"
#include "reduction_common.hpp"
#include "gridwise_generic_2d_reduction_direct_warpwise.hpp"

using namespace ck;

using srcDataType =
    typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_SRC_DATATYPE)>::type;
using dstDataType =
    typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_DST_DATATYPE)>::type;
using compType =
    typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_REDUCE_COMPTYPE)>::type;

constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable

constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
                                             ? NanPropagation_t::NOT_PROPAGATE_NAN
                                             : NanPropagation_t::PROPAGATE_NAN;
constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
                                                       ? ReduceTensorIndices_t::NO_INDICES
                                                       : ReduceTensorIndices_t::FLATTENED_INDICES;

constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);

constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);

constexpr index_t GredAccessesPerThreadInWarp = CK_PARAM_ACCESSES_PER_THREAD_INWARP; // tunable

extern "C" __global__ void
|
||||
gridwise_generic_reduce_2_prepare(int GridSize, int BlkGroupSize, void* __restrict__ ws_global)
|
||||
{
|
||||
(void)BlkGroupSize;
|
||||
|
||||
void* p_src2dDesc = ws_global;
|
||||
void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
|
||||
|
||||
const auto tupleDstLengths = make_tuple(1);
|
||||
const auto tupleDstStrides = make_tuple(1);
|
||||
|
||||
auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
|
||||
|
||||
const index_t invariantLen = dstDesc.GetLength(Number<0>{});
|
||||
const index_t toReduceLen = BlkGroupSize;
|
||||
|
||||
auto src2dDesc = make_naive_tensor_descriptor_packed(make_tuple(invariantLen, toReduceLen));
|
||||
|
||||
constexpr auto copySliceLen = warpSize * GredAccessesPerThreadInWarp;
|
||||
|
||||
if constexpr(src2d_need_padding)
|
||||
{
|
||||
const auto srcPad1 = GridSize * BlockSize / warpSize - invariantLen;
|
||||
const auto srcPad2 =
|
||||
((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen;
|
||||
|
||||
auto src2dDesc_2 =
|
||||
transform_tensor_descriptor(src2dDesc,
|
||||
make_tuple(make_pad_transform(invariantLen, 0, srcPad1),
|
||||
make_pad_transform(toReduceLen, 0, srcPad2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
|
||||
}
|
||||
|
||||
if constexpr(dst1d_need_padding)
|
||||
{
|
||||
const auto dstPad = GridSize * BlockSize / warpSize - invariantLen;
|
||||
auto dst1dDesc_2 =
|
||||
transform_tensor_descriptor(dstDesc,
|
||||
make_tuple(make_pad_transform(invariantLen, 0, dstPad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(dst1dDesc_2)*>(p_dst1dDesc) = dst1dDesc_2;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
*static_cast<decltype(dstDesc)*>(p_dst1dDesc) = dstDesc;
|
||||
}
|
||||
};
|
||||
|
||||
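// A worked example of the warpwise sizing above; the numbers are assumed for
// illustration (warpSize is 64 on AMD GCN/CDNA GPUs). With BlockSize = 256 and
// GredAccessesPerThreadInWarp = 2:
//
//     warps per block = 256 / 64 = 4
//     total warps     = GridSize * BlockSize / warpSize   (one output row each)
//     copySliceLen    = 64 * 2   = 128
//
// so srcPad1 rounds the invariant length up to the total warp count, and
// srcPad2 rounds the reduce length up to a multiple of the 128-element slice
// one warp consumes per pass.
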
struct get_ref_desc_types
{
    static constexpr auto ref_tupleDstLengths = make_tuple(8);
    static constexpr auto ref_dstDesc =
        make_naive_tensor_descriptor(ref_tupleDstLengths, ref_tupleDstLengths);

    static constexpr index_t ref_invariantLen = ref_dstDesc.GetLength(Number<0>{});
    static constexpr index_t ref_toReduceLen = 8;

    static constexpr auto ref_src2dDesc =
        make_naive_tensor_descriptor_packed(make_tuple(ref_invariantLen, ref_toReduceLen));

    using refType_src2dDesc = decltype(ref_src2dDesc);
    using refType_dst1dDesc = decltype(ref_dstDesc);

    // used by the DirectThreadWise and DirectWarpWise method
    using refType_src2dDesc_padded_12 =
        decltype(transform_tensor_descriptor(ref_src2dDesc,
                                             make_tuple(make_pad_transform(ref_invariantLen, 0, 2),
                                                        make_pad_transform(ref_toReduceLen, 0, 2)),
                                             make_tuple(Sequence<0>{}, Sequence<1>{}),
                                             make_tuple(Sequence<0>{}, Sequence<1>{})));

    using refType_dst1dDesc_padded =
        decltype(transform_tensor_descriptor(ref_dstDesc,
                                             make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
                                             make_tuple(Sequence<0>{}),
                                             make_tuple(Sequence<0>{})));
};

using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc;
using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc;
using refType_src2dDesc_padded_12 = typename get_ref_desc_types::refType_src2dDesc_padded_12;
using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded;

template <bool need_padding>
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
{
    if constexpr(need_padding)
        return (*reinterpret_cast<const refType_src2dDesc_padded_12*>(p_src2dDesc));
    else
        return (*reinterpret_cast<const refType_src2dDesc*>(p_src2dDesc));
};

template <bool need_padding>
static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc)
{
    if constexpr(need_padding)
        return (*reinterpret_cast<const refType_dst1dDesc_padded*>(p_dst1dDesc));
    else
        return (*reinterpret_cast<const refType_dst1dDesc*>(p_dst1dDesc));
};

extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen,
|
||||
float alpha,
|
||||
const void* __restrict__ p_src_global,
|
||||
float beta,
|
||||
void* __restrict__ p_dst_global,
|
||||
const void CONSTANT* ws_global,
|
||||
long ws_buf2_bytes_offset,
|
||||
void* __restrict__ indices_global)
|
||||
{
|
||||
(void)p_src_global;
|
||||
|
||||
const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
|
||||
const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
|
||||
void* ws_buf1_global = const_cast<char*>(static_cast<const char*>(p_src2dDesc) + 4096);
|
||||
|
||||
const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
|
||||
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
|
||||
|
||||
using gridwise_2d_reduce =
|
||||
GridwiseReduction_xy_to_x_direct_warpwise<BlockSize,
|
||||
srcDataType,
|
||||
dstDataType,
|
||||
compType,
|
||||
decltype(src2dDesc),
|
||||
decltype(dst1dDesc),
|
||||
op,
|
||||
nanPropaOpt,
|
||||
reduceIndicesOpt,
|
||||
false,
|
||||
true,
|
||||
GredAccessesPerThreadInWarp>;
|
||||
|
||||
void* const ws_buf2_global =
|
||||
ws_buf2_bytes_offset > 0
|
||||
? static_cast<void*>(static_cast<char*>(ws_buf1_global) + ws_buf2_bytes_offset)
|
||||
: nullptr;
|
||||
|
||||
constexpr int RunId = need_indices ? 3 : 1;
|
||||
gridwise_2d_reduce::template Run<RunId>(
|
||||
src2dDesc,
|
||||
dst1dDesc,
|
||||
origReduceLen,
|
||||
alpha,
|
||||
static_cast<const srcDataType* const __restrict__>(ws_buf1_global),
|
||||
beta,
|
||||
static_cast<dstDataType* const __restrict__>(p_dst_global),
|
||||
static_cast<const int* const __restrict__>(ws_buf2_global),
|
||||
static_cast<int* const __restrict__>(indices_global));
|
||||
};
|
||||
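// A minimal sketch of the ws_buf2_global carve-out above: the index buffer
// shares a single allocation with ws_buf1_global, so locating it is plain byte
// arithmetic (the offset value is illustrative):
//
//     char* base = static_cast<char*>(ws_buf1_global);
//     void* idx  = ws_buf2_bytes_offset > 0 ? base + ws_buf2_bytes_offset  // e.g. +1024
//                                           : nullptr;
//
// A null ws_buf2_global means the first call produced no workspace indices for
// this configuration, which the plain-value path (Run<1>) is not expected to
// read.
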
@@ -42,12 +42,8 @@ using compType =

constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable

constexpr index_t srcDims = CK_PARAM_IN_DIMS;
constexpr index_t dstDims = CK_PARAM_OUT_DIMS;

using toReduceDims = Sequence<CK_PARAM_TOREDUCE_DIMS>;
using invariantDims = Sequence<CK_PARAM_INVARIANT_DIMS>; // this could be empty

constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
                                             ? NanPropagation_t::NOT_PROPAGATE_NAN
@@ -59,16 +55,6 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);

////////////////////////////////////////////////////////////////////////////////////////
using specDims = typename sequence_merge<invariantDims, toReduceDims>::type;

static_assert(is_valid_sequence_map<specDims>::value && specDims::Size() == srcDims,
              "Wrong invariant and/or toReduce dimensions!");

// The number of invariant dimensions can be zero if all dimension are to be reduced
static_assert(invariantDims::Size() > 0 || dstDims == 1,
              "If all source dimensions are reduced, the dest should have only one dimension !!");

constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);

@@ -153,12 +139,12 @@ extern "C" __global__ void gridwise_generic_reduce_2_prepare(int GridSize,
                                        make_pad_transform(toReduceLen, 0, srcPad2)),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));
        if(hipThreadIdx_x == 0)
        if(get_thread_local_1d_id() == 0)
            *static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
    }
    else
    {
        if(hipThreadIdx_x == 0)
        if(get_thread_local_1d_id() == 0)
            *static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
    }

@@ -170,17 +156,17 @@ extern "C" __global__ void gridwise_generic_reduce_2_prepare(int GridSize,
                                        make_tuple(make_pad_transform(invariantLen, 0, dstPad)),
                                        make_tuple(Sequence<0>{}),
                                        make_tuple(Sequence<0>{}));
        if(hipThreadIdx_x == 0)
        if(get_thread_local_1d_id() == 0)
            *static_cast<decltype(dst1dDesc_2)*>(p_dst1dDesc) = dst1dDesc_2;
    }
    else
    {
        if(hipThreadIdx_x == 0)
        if(get_thread_local_1d_id() == 0)
            *static_cast<decltype(dst1dDesc)*>(p_dst1dDesc) = dst1dDesc;
    }
};

template <index_t srcDims, index_t dstDims, typename invariantDims, typename toReduceDims>
template <index_t dstDims>
struct get_ref_desc_types
{
    static constexpr auto ref_tupleDstLengths =
@@ -218,16 +204,11 @@ struct get_ref_desc_types
                                         make_tuple(Sequence<0>{})));
};

using refType_src2dDesc =
    typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::refType_src2dDesc;
using refType_dst1dDesc =
    typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::refType_dst1dDesc;
using refType_src2dDesc = typename get_ref_desc_types<dstDims>::refType_src2dDesc;
using refType_dst1dDesc = typename get_ref_desc_types<dstDims>::refType_dst1dDesc;
using refType_src2dDesc_padded_12 =
    typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::
        refType_src2dDesc_padded_12;
using refType_dst1dDesc_padded =
    typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::
        refType_dst1dDesc_padded;
    typename get_ref_desc_types<dstDims>::refType_src2dDesc_padded_12;
using refType_dst1dDesc_padded = typename get_ref_desc_types<dstDims>::refType_dst1dDesc_padded;

template <bool need_padding>
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
@@ -252,15 +233,15 @@ extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen,
                                                     const void* __restrict__ p_src_global,
                                                     float beta,
                                                     void* __restrict__ p_dst_global,
                                                     void* __restrict__ ws_global,
                                                     const void CONSTANT* ws_global,
                                                     long ws_buf2_bytes_offset,
                                                     void* __restrict__ indices_global)
{
    (void)p_src_global;

    const void* p_src2dDesc = ws_global;
    const void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
    void* ws_buf1_global = static_cast<char*>(ws_global) + 4096;
    const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
    const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
    void* ws_buf1_global = const_cast<char*>(static_cast<const char*>(p_src2dDesc) + 4096);

    const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
    const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);