Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-04-19 22:39:03 +00:00)
Standalone softmax kernel (#284)
* initial stub for standalone softmax
* start device_softmax_mk_to_mk as a wrapper to device_reduce_mk_to_m
* host softmax validates
* compiles; to implement beta scaling
* use NaN trick to efficiently ignore OOB values during sum of exponentials
* freeload device_reduce's utility functions
* clean up interface
* adding prior value (beta scaling)
* remove restriction related to perf considerations
* apply clang-format
* clean; disable diagnostics
* resolve conflicts
* add exp wrapper
* honor HostTensorDesc interface; allow implicit cast from different vector<T> type
* test softmax for fp16/fp32
* update readme
* amend commit NaN trick
* remove redundant param added during development
* format
* replace ScalarDataType with AccDataType
* separate out test programs by precision type
* move softmax sample code to its own folder
* format
* keep up with recent changes in reduction API
* remove extra header
@@ -5,14 +5,14 @@
# -D <xxx> : input 4-d tensor lengths
# -v <x> : verification (0=no, 1=yes)
#arg1: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value)
#arg2: time kernel (0=no, 1=yes)
./bin/example_reduce_blockwise -D 16,64,32,960 -v 1 1 1
```

Result
```
./bin/example_reduce_blockwise -D 16,64,32,960 -v 1 1 1
launch_and_time_kernel: grid_dim {240, 1, 1}, block_dim {256, 1, 1}
Warm up 1 time
Start running 10 times...
Perf: 0.282592 ms, 222.641 GB/s, DeviceReduceBlockWise<256,M_C4_S1,K_C64_S1,InSrcVectorDim_0_InSrcVectorSize_1_OutDstVectorSize_1>
@@ -24,19 +24,18 @@ Perf: 0.282592 ms, 222.641 GB/s, DeviceReduceBlockWise<256,M_C4_S1,K_C64_S1,InSr
```bash
#arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value)
#arg3: time kernel (0=no, 1=yes)
./bin/example_reduce_blockwise_two_call 1 2 1

```

Result
```
./bin/example_reduce_blockwise_two_call 1 2 1
launch_and_time_kernel: grid_dim {204800, 1, 1}, block_dim {256, 1, 1}
Warm up 1 time
Start running 10 times...
launch_and_time_kernel: grid_dim {6400, 1, 1}, block_dim {256, 1, 1}
Warm up 1 time
Start running 10 times...
Perf: 2.1791 ms, 771.42 GB/s, DeviceReduceBlockWise<256,M_C32_S1,K_C8_S1,InSrcVectorDim_1_InSrcVectorSize_1_OutDstVectorSize_1> => DeviceReduceBlockWise<256,M_C256_S1,K_C1_S1,InSrcVectorDim_1_InSrcVectorSize_1_OutDstVectorSize_1>
```

1 example/23_softmax/CMakeLists.txt Normal file
@@ -0,0 +1 @@
add_example_executable(example_softmax_blockwise softmax_blockwise.cpp)
18 example/23_softmax/README.md Normal file
@@ -0,0 +1,18 @@
# Instructions for ```example_softmax_blockwise```

## Run ```example_softmax_blockwise```
```bash
# -D <xxx> : input 3-d tensor lengths
# -v <x> : verification (0=no, 1=yes)
#arg1: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value)
#arg2: time kernel (0=no, 1=yes)
example_softmax_blockwise -D 4,128,2048 -v 1 1 1
```

Result
```
launch_and_time_kernel: grid_dim {64, 1, 1}, block_dim {256, 1, 1}
Warm up 1 time
Start running 10 times...
Perf: 0.0242877 ms, 259.039 GB/s, DeviceReduceSoftmax<256,M_C8_S1,K_C32_S8,InSrcVectorDim_1_InSrcVectorSize_8_OutDstVectorSize_8>
```
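
For orientation: the example computes a numerically stable softmax with prior-value blending, out = alpha * softmax(in) + beta * out_prior, along the reduced dimension. A minimal host-side sketch of the same math follows (an illustration only; `softmax_row` is a hypothetical helper, not part of this commit):

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// y = alpha * softmax(x) + beta * y, over one row.
// Subtracting max(x) first keeps exp() from overflowing.
void softmax_row(const std::vector<float>& x, std::vector<float>& y, float alpha, float beta)
{
    const float x_max = *std::max_element(x.begin(), x.end());
    float sum         = 0.f;
    for(float v : x)
        sum += std::exp(v - x_max);
    for(std::size_t i = 0; i < x.size(); ++i)
        y[i] = alpha * std::exp(x[i] - x_max) / sum + beta * y[i];
}
```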

255 example/23_softmax/softmax_blockwise.cpp Normal file
@@ -0,0 +1,255 @@
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <getopt.h>

#include "check_err.hpp"
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "device_base.hpp"
#include "device_softmax.hpp"
#include "host_common_util.hpp"
#include "reference_softmax.hpp"

#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"

using namespace ck;
using namespace ck::tensor_operation::device;

using InDataType  = ck::half_t;
using OutDataType = ck::half_t;
using AccDataType = float;

constexpr int Rank         = 3;
constexpr int NumReduceDim = 1;

using DeviceInstance = DeviceSoftmax<InDataType,
                                     AccDataType,
                                     OutDataType,
                                     Rank,
                                     NumReduceDim,
                                     256, // BlockSize
                                     8,   // ClusterM
                                     32,  // ClusterK
                                     1,   // SliceM
                                     8,   // SliceK
                                     1,   // SrcVecDim (0=M, 1=K)
                                     8,   // SrcScalarPerVector
                                     8>;  // OutScalarPerVector

static struct option long_options[] = {{"inLengths", required_argument, nullptr, 'D'},
                                       {"verify", required_argument, nullptr, 'v'},
                                       {"help", no_argument, nullptr, '?'},
                                       {nullptr, 0, nullptr, 0}};

class SimpleAppArgs
{
    private:
    int option_index = 0;

    public:
    std::vector<size_t> inLengths   = {8, 128, 2048};
    std::vector<AccDataType> scales = {2.0f, 2.0f};

    bool do_verification = true;
    int init_method      = 2;
    bool time_kernel     = true;

    public:
    void show_usage(const char* cmd)
    {
        std::cout << "Usage of " << cmd << std::endl;
        std::cout << "--inLengths or -D, comma separated list of input tensor dimension lengths"
                  << std::endl;
        std::cout << "--verify or -v, 1/0 to indicate whether to verify the reduction result by "
                     "comparing with the host-based reduction"
                  << std::endl;
        std::cout << "Arg1 -- init method (0=no init, 1=single integer value, 2=scope integer "
                     "value, 3=decimal value)"
                  << std::endl;
        std::cout << "Arg2 -- time kernel (0=no, 1=yes)" << std::endl;
    };

    int processArgs(int argc, char* argv[])
    {
        using ck::host_common::getTypeValuesFromString;

        int ch;

        while(1)
        {
            ch = getopt_long(argc, argv, "D:v:l:", long_options, &option_index);
            if(ch == -1)
                break;
            switch(ch)
            {
            case 'D':
                if(!optarg)
                    throw std::runtime_error("Invalid option format!");

                inLengths = getTypeValuesFromString<size_t>(optarg);
                break;
            case 'v':
                if(!optarg)
                    throw std::runtime_error("Invalid option format!");

                do_verification = static_cast<bool>(std::atoi(optarg));
                break;
            case '?':
                if(std::string(long_options[option_index].name) == "help")
                {
                    show_usage(argv[0]);
                    return (-1);
                };
                break;
            default: show_usage(argv[0]); return (-1);
            };
        };

        if(optind + 2 > argc)
            throw std::runtime_error("Invalid cmd-line arguments, more arguments are needed!");

        init_method = std::atoi(argv[optind++]);
        time_kernel = static_cast<bool>(std::atoi(argv[optind]));

        if(scales.empty())
        {
            scales.push_back(1.0f);
            scales.push_back(0.0f);
        };

        return (0);
    };
};

int main(int argc, char* argv[])
{
    // Example: batched gemm C[G, M, N] applies max/sum reduction along N internally
    const std::vector<int> invariantDims{0, 1};
    const std::vector<int> reduceDims{2};

    SimpleAppArgs args;

    if(argc > 1)
    {
        if(args.processArgs(argc, argv) < 0)
            return (-1);
    };

    Tensor<InDataType> in(args.inLengths);
    Tensor<OutDataType> out_ref(args.inLengths);
    Tensor<OutDataType> out(args.inLengths);

    auto inStrides  = in.mDesc.GetStrides();
    auto outStrides = out.mDesc.GetStrides();

    AccDataType alpha = args.scales[0];
    AccDataType beta  = args.scales[1];

    std::size_t num_thread = 1;

    if(args.do_verification)
    {
        switch(args.init_method)
        {
        case 0: break;
        case 1:
            in.GenerateTensorValue(GeneratorTensor_1<InDataType>{1}, num_thread);
            if(beta != 0.0f)
                out_ref.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1}, num_thread);
            break;
        case 2:
            in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}, num_thread);
            if(beta != 0.0f)
                out_ref.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5}, num_thread);
            break;
        default:
            in.GenerateTensorValue(GeneratorTensor_3<InDataType>{-5.0, 5.0}, num_thread);
            if(beta != 0.0f)
                out_ref.GenerateTensorValue(GeneratorTensor_3<OutDataType>{-5.0, 5.0}, num_thread);
        }

        if(beta != 0.0f)
            for(size_t i = 0; i < out_ref.mDesc.GetElementSpace(); i++)
                out.mData[i] = out_ref.mData[i];
    };
    // std::cout << "beta = " << beta << std::endl;
    // LogRangeAsType<float>(std::cout << "tensor in: " , in.mData, ",") << std::endl;
    // LogRangeAsType<float>(std::cout << "tensor prior out: " , out.mData, ",") << std::endl;

    // these buffers are usually provided by the user application
    DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpace());
    DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpace());

    in_dev.ToDevice(in.mData.data());

    if(beta != 0.0f)
        out_dev.ToDevice(out.mData.data());

    if(args.do_verification)
    {
        using ReferenceInstance =
            tensor_operation::host::ReferenceSoftmax<InDataType, OutDataType, AccDataType>;
        ReferenceInstance ref;
        auto ref_arg = ref.MakeArgument(in, out_ref, alpha, beta, Rank, reduceDims);
        auto invoker = ref.MakeInvoker();
        invoker.Run(ref_arg);
        // LogRangeAsType<float>(std::cout << "tensor out_ref: ", out_ref.mData, ",") << std::endl;
    };

    std::vector<ck::index_t> i_inLengths;
    std::vector<ck::index_t> i_inStrides;

    i_inLengths.assign(args.inLengths.begin(), args.inLengths.end());
    i_inStrides.assign(inStrides.begin(), inStrides.end());

    auto device_instance = DeviceInstance{};

    auto argument_ptr = device_instance.MakeArgumentPointer(i_inLengths,
                                                            i_inStrides,
                                                            reduceDims,
                                                            alpha,
                                                            beta,
                                                            in_dev.GetDeviceBuffer(),
                                                            out_dev.GetDeviceBuffer());

    if(!device_instance.IsSupportedArgument(argument_ptr.get()))
    {
        std::cout
            << "The runtime parameters are not supported by the DeviceReduce instance, exiting!"
            << std::endl;
        return 1;
    };

    std::string instance_name = device_instance.GetTypeString();

    auto invoker_ptr = device_instance.MakeInvokerPointer();

    bool pass = true;
    if(args.do_verification)
    {
        invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
        out_dev.FromDevice(out.mData.data());
        // LogRangeAsType<float>(std::cout << "tensor out: " , out.mData, ",") << std::endl;
        pass = pass && ck::utils::check_err(out.mData, out_ref.mData);
    };

    float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, args.time_kernel});

    std::size_t num_bytes =
        in.mDesc.GetElementSize() * sizeof(InDataType) +
        (beta == 0.0f ? 1 : 2) * out.mDesc.GetElementSize() * sizeof(OutDataType);
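    // (traffic model: one full read of the input, one write of the output, plus
    // one extra read of the prior output when beta != 0)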

    float gb_per_sec = num_bytes / 1.E6 / avg_time;

    std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s, " << instance_name
              << std::endl;

    return (pass ? 0 : 1);
}
@@ -56,3 +56,4 @@ add_subdirectory(19_binary_elementwise)
add_subdirectory(20_convnd_bwd_weight_xdl)
add_subdirectory(21_gemm_layernorm)
add_subdirectory(22_cgemm)
add_subdirectory(23_softmax)

@@ -45,7 +45,9 @@ template <typename AccDataType,
          typename ThreadClusterLengths_M_K,
          typename ThreadClusterArrangeOrder,
          typename OpReduce,
          bool PropagateNan>
          bool PropagateNan,
          typename Accumulation =
              detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>>
struct PartitionedBlockwiseReduction
{
    static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1),
@@ -62,8 +64,6 @@ struct PartitionedBlockwiseReduction
    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>;

    template <typename BufferType>
    __device__ static void Reduce(BufferType& work_buffer, AccDataType& in_out_value)
    {
@@ -113,13 +113,16 @@ struct PartitionedBlockwiseReduction
// 3) in_out_value/in_out_index is the input data in vgpr from each thread
// 4) in_out_value/in_out_index is the over-written reduced output in vgpr for each thread
// clang-format on
template <typename AccDataType,
          typename IndexDataType,
          index_t BlockSize,
          typename ThreadClusterLengths_M_K,
          typename ThreadClusterArrangeOrder,
          typename OpReduce,
          bool PropagateNan>
template <
    typename AccDataType,
    typename IndexDataType,
    index_t BlockSize,
    typename ThreadClusterLengths_M_K,
    typename ThreadClusterArrangeOrder,
    typename OpReduce,
    bool PropagateNan,
    typename Accumulation =
        detail::AccumulateWithIndexAndNanCheck<PropagateNan, OpReduce, AccDataType, IndexDataType>>
struct PartitionedBlockwiseReductionWithIndex
{
    static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1),
@@ -136,9 +139,6 @@ struct PartitionedBlockwiseReductionWithIndex
    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    using Accumulation =
        detail::AccumulateWithIndexAndNanCheck<PropagateNan, OpReduce, AccDataType, IndexDataType>;

    // This interface accumulates on both data values and indices
    template <typename BufferType, typename IdxBufferType>
    __device__ static void Reduce(BufferType& work_val_buffer,

@@ -390,10 +390,8 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
        };
    };

    bool IsSupportedArgument(const BaseArgument* p_arg) override
    static bool IsSupportedArgument(const Argument* pArg)
    {
        const Argument* pArg = dynamic_cast<const Argument*>(p_arg);

        if constexpr(use_multiblock)
        {
            if(static_cast<float>(pArg->beta_) != 0.0f)
@@ -442,11 +440,16 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
        else
        {
            // cases with very small reduce_total_length should be handled by ThreadWise kernel
            if(pArg->reduce_total_length / KThreadSliceSize < 2)
                return (false);
            // if(pArg->reduce_total_length / KThreadSliceSize < 2)
            //     return (false);
        };

        return (true);
    }

    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(dynamic_cast<const Argument*>(p_arg));
    };

    std::unique_ptr<BaseArgument>

203 include/ck/tensor_operation/gpu/device/device_softmax.hpp Normal file
@@ -0,0 +1,203 @@
#ifndef DEVICE_SOFTMAX_HPP
#define DEVICE_SOFTMAX_HPP

#include <iostream>
#include <sstream>
#include "device.hpp"
#include "device_base.hpp"
#include "device_reduce.hpp"
#include "device_reduce_multiblock.hpp"
#include "device_reduce_common.hpp"
#include "gridwise_softmax.hpp"
#include "gridwise_set_buffer_value.hpp"
#include "reduction_operator.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename InDataType,
          typename AccDataType,
          typename OutDataType,
          index_t Rank,
          index_t NumReduceDim,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t InSrcVectorDim,
          index_t InSrcVectorSize,
          index_t OutDstVectorSize>
struct DeviceSoftmax : public BaseOperator
{
    using PassThrough = tensor_operation::element_wise::PassThrough;

    // Used for freeloading of some handy functions from DeviceReduceMultiBlock
    using Reduction = DeviceReduceMultiBlock<InDataType,
                                             AccDataType,
                                             OutDataType,
                                             Rank,
                                             NumReduceDim,
                                             reduce::Add,
                                             PassThrough, // InElementwiseOperation
                                             PassThrough, // AccElementwiseOperation
                                             InMemoryDataOperationEnum::Set,
                                             false, // PropagateNan
                                             false, // OutputIndex
                                             false, // HaveIndexInputIfOutputIndex
                                             BlockSize,
                                             MThreadClusterSize,
                                             KThreadClusterSize,
                                             MThreadSliceSize,
                                             KThreadSliceSize,
                                             InSrcVectorDim,
                                             InSrcVectorSize,
                                             1>; // OutDstVectorSize

    using GridDesc_M_K = decltype(Reduction::MakeSrc2dDescriptor({1}, {1}, 1, 1));
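    // (the Rank-d input is viewed as a 2-d (M, K) problem: M spans the kept
    // dimensions and K spans the NumReduceDim dimensions being reduced over)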

    using GridwiseReduce = GridwiseSoftmax_mk_to_mk<InDataType,
                                                    OutDataType,
                                                    AccDataType,
                                                    GridDesc_M_K,
                                                    BlockSize,
                                                    MThreadClusterSize,
                                                    KThreadClusterSize,
                                                    MThreadSliceSize,
                                                    KThreadSliceSize,
                                                    InSrcVectorDim,
                                                    InSrcVectorSize,
                                                    OutDstVectorSize>;

    struct Argument : public Reduction::Argument
    {
        Argument(const std::vector<index_t> inLengths,
                 const std::vector<index_t> inStrides,
                 const std::vector<index_t> reduceDims,
                 AccDataType alpha,
                 AccDataType beta,
                 const InDataType* in_dev,
                 OutDataType* out_dev)
            : Reduction::Argument(inLengths,
                                  inStrides,
                                  {},
                                  {},
                                  reduceDims,
                                  0.0f, // alpha
                                  0.0f, // beta
                                  in_dev,
                                  nullptr,
                                  out_dev,
                                  nullptr,
                                  PassThrough{},
                                  PassThrough{}),
              // FIXME: The base class DeviceReduceMultiBlock::Argument only supports alpha/beta of
              // float32 precision. Make it support any data type so the fields can be removed.
              alpha_(alpha),
              beta_(beta)
        {
            // std::cout << "blkGroupSize= " << this->blkGroupSize
            //           << ", numBlockTileIteration= " << this->numBlockTileIteration
            //           << ", gridSize=" << this->gridSize
            //           << ", invariant_total_length=" << this->invariant_total_length <<
            //           std::endl;
        }

        AccDataType alpha_;
        AccDataType beta_;
    };

    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            const auto in_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
                arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
            const auto out_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
                arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);

            const auto kernel_main =
                kernel_softmax<GridwiseReduce, InDataType, OutDataType, AccDataType, GridDesc_M_K>;

            float avg_time = 0;

            avg_time += launch_and_time_kernel(stream_config,
                                               kernel_main,
                                               dim3(arg.gridSize),
                                               dim3(BlockSize),
                                               0,
                                               in_grid_desc_m_k,
                                               out_grid_desc_m_k,
                                               arg.blkGroupSize,
                                               arg.numBlockTileIteration,
                                               arg.alpha_,
                                               arg.in_dev_,
                                               arg.beta_,
                                               arg.out_dev_);

            return (avg_time);
        };

        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        };
    };

    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);

        if(!Reduction::IsSupportedArgument(p_arg_))
        {
            return false;
        }

        if(p_arg_->inLengths_[Rank - 1] % OutDstVectorSize != 0)
        {
            return false;
        }

        return true;
    };

    std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths,
                                                      const std::vector<index_t> inStrides,
                                                      const std::vector<int> reduceDims,
                                                      AccDataType alpha,
                                                      AccDataType beta,
                                                      const void* in_dev,
                                                      void* out_dev)
    {
        return std::make_unique<Argument>(inLengths,
                                          inStrides,
                                          reduceDims,
                                          alpha,
                                          beta,
                                          static_cast<const InDataType*>(in_dev),
                                          static_cast<OutDataType*>(out_dev));
    };

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); };

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceReduceSoftmax<" << BlockSize << ",";
        str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
        str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
        str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif // DEVICE_SOFTMAX_HPP

407 include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp Normal file
@@ -0,0 +1,407 @@
/*******************************************************************************
 *
 * MIT License
 *
 * Copyright (c) 2022 Advanced Micro Devices, Inc.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 *
 *******************************************************************************/
#ifndef GRIDWISE_SOFTMAX_HPP
#define GRIDWISE_SOFTMAX_HPP

#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp"
#include "reduction_functions_threadwise.hpp"

#include "threadwise_tensor_slice_transfer.hpp"
#include "element_wise_operation.hpp"

namespace ck {

template <typename GridwiseReduction,
          typename InDataType,
          typename OutDataType,
          typename AccDataType,
          typename GridDesc_M_K>
__global__ void kernel_softmax(const GridDesc_M_K in_grid_desc_m_k,
                               const GridDesc_M_K out_grid_desc_m_k,
                               index_t block_group_size,
                               index_t num_k_block_tile_iteration,
                               AccDataType alpha,
                               const InDataType* const __restrict__ p_in_value_global,
                               AccDataType beta,
                               OutDataType* const __restrict__ p_out_value_global)
{
    GridwiseReduction::Run(in_grid_desc_m_k,
                           out_grid_desc_m_k,
                           block_group_size,
                           num_k_block_tile_iteration,
                           alpha,
                           p_in_value_global,
                           beta,
                           p_out_value_global);
};

template <typename InDataType,
          typename OutDataType,
          typename AccDataType,
          typename GridDesc_M_K,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t InSrcVectorDim,
          index_t InSrcVectorSize,
          index_t OutDstVectorSize>
struct GridwiseSoftmax_mk_to_mk
{
    static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
                   (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
                      (KThreadSliceSize % OutDstVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);

    using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;

    using ThreadBufferDimAccessOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    using ThreadClusterArrangeOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
    using ThreadReduceDstDesc_M =
        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

    using BlockwiseMaxReduce = PartitionedBlockwiseReduction<AccDataType,
                                                             BlockSize,
                                                             ThreadClusterLengths_M_K,
                                                             ThreadClusterArrangeOrder,
                                                             reduce::Max,
                                                             false>; // PropagateNan

    using ThreadwiseMaxReduce = ThreadwiseReduction<AccDataType,
                                                    ThreadReduceSrcDesc_M_K,
                                                    ThreadReduceDstDesc_M,
                                                    reduce::Max,
                                                    false>; // PropagateNan

    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

    __device__ static void Run(const GridDesc_M_K& in_grid_desc_m_k,
                               const GridDesc_M_K& out_grid_desc_m_k,
                               index_t block_group_size,
                               index_t num_k_block_tile_iteration,
                               AccDataType alpha,
                               const InDataType* const __restrict__ p_in_value_global,
                               AccDataType beta,
                               OutDataType* const __restrict__ p_out_value_global)
    {
        // LDS
        __shared__ AccDataType p_reduce_work_buffer[BlockSize];

        auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_out_value_global, out_grid_desc_m_k.GetElementSpaceSize());

        auto reduce_work_buf =
            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            in_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            out_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> max_value_buf;

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            max_value_buf(I) = reduce::Max::template GetIdentityValue<AccDataType>();
        });

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            accu_value_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
        });

        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();
        const index_t blkgroup_id     = block_global_id / block_group_size;
        const index_t block_local_id  = block_global_id % block_group_size;

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;

        using ThreadBufferLengths         = Sequence<MThreadSliceSize, KThreadSliceSize>;
        constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

        auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
                                                                    AccDataType,
                                                                    GridDesc_M_K,
                                                                    decltype(thread_buffer_desc),
                                                                    ThreadBufferLengths,
                                                                    ThreadBufferDimAccessOrder,
                                                                    InSrcVectorDim,
                                                                    InSrcVectorSize,
                                                                    1,
                                                                    false>(
            in_grid_desc_m_k,
            make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                             block_local_id * reduceSizePerBlock +
                                 thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<OutDataType,
                                                                    AccDataType,
                                                                    GridDesc_M_K,
                                                                    decltype(thread_buffer_desc),
                                                                    ThreadBufferLengths,
                                                                    ThreadBufferDimAccessOrder,
                                                                    InSrcVectorDim,
                                                                    InSrcVectorSize,
                                                                    1,
                                                                    false>(
            out_grid_desc_m_k,
            make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                             block_local_id * reduceSizePerBlock +
                                 thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_dst_store =
            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                               OutDataType,
                                               decltype(thread_buffer_desc),
                                               GridDesc_M_K,
                                               PassThroughOp,
                                               ThreadBufferLengths,
                                               ThreadBufferDimAccessOrder,
                                               InSrcVectorDim,
                                               OutDstVectorSize,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
                out_grid_desc_m_k,
                make_multi_index(
                    blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                    block_local_id * reduceSizePerBlock + thread_k_cluster_id * KThreadSliceSize),
                PassThroughOp{});

        constexpr auto in_thread_copy_fwd_step = make_multi_index(0, K_BlockTileSize);
        constexpr auto in_thread_copy_bwd_step = make_multi_index(0, -K_BlockTileSize);

        ///
        /// max(x)
        ///
        const auto in_global_val_buf_oob_non_zero = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_value_global,
            in_grid_desc_m_k.GetElementSpaceSize(),
            reduce::Max::template GetIdentityValue<InDataType>());
        index_t reducedTiles = 0;
        do
        {
            threadwise_src_load.Run(in_grid_desc_m_k,
                                    in_global_val_buf_oob_non_zero,
                                    thread_buffer_desc,
                                    make_tuple(I0, I0),
                                    in_thread_buf);

            ThreadwiseMaxReduce::Reduce(in_thread_buf, max_value_buf);

            threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);

            reducedTiles++;
        } while(reducedTiles < num_k_block_tile_iteration);

        static_for<0, MThreadSliceSize, 1>{}(
            [&](auto I) { BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I)); });

        threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step);

        ///
        /// sum(exp(x - max(x)))
        ///
        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            accu_value_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
        });

        // Normally 0 is an adequate padding value for out-of-bounds elements, since 0
        // contributes nothing to an accumulated sum. In stable softmax, however, every
        // value, valid or not, first has value_max subtracted from it; the padding then
        // becomes non-zero and would slip through and contribute to the accumulated result.
        //
        // The trick here is to exploit the fact that most math operations (add, sub,
        // exp, ...) propagate NaNs. By initializing out-of-bounds elements with NaN, a
        // value that failed the bound check is still NaN after any number of math
        // manipulations, so it can still be identified, and discarded, during
        // accumulation.
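        // For example: an out-of-bounds element loaded as NaN stays NaN through
        // exp(NaN - value_max), and the AccumulateWithNanIgnore accumulator below
        // then skips it, so it contributes nothing to the sum, whereas a
        // 0-initialized element would have contributed exp(0 - value_max) > 0.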
        const auto in_global_val_buf_oob_nan =
            make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
                                                          in_grid_desc_m_k.GetElementSpaceSize(),
                                                          NumericLimits<InDataType>::QuietNaN());

        using BlockwiseSumReduce = PartitionedBlockwiseReduction<
            AccDataType,
            BlockSize,
            ThreadClusterLengths_M_K,
            ThreadClusterArrangeOrder,
            reduce::Add,
            false, // ignored
            detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;

        using ThreadwiseSumReduce =
            ThreadwiseReduction<AccDataType,
                                ThreadReduceSrcDesc_M_K,
                                ThreadReduceDstDesc_M,
                                reduce::Add,
                                false, // ignored
                                detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;

        reducedTiles = 0;
        do
        {
            threadwise_src_load.Run(in_grid_desc_m_k,
                                    in_global_val_buf_oob_nan,
                                    thread_buffer_desc,
                                    make_tuple(I0, I0),
                                    in_thread_buf);

            // do element-wise pre-reduction operation
            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                    constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
                    in_thread_buf(Number<offset>{}) =
                        math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM));
                });
            });

            ThreadwiseSumReduce::Reduce(in_thread_buf, accu_value_buf);

            threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step);

            reducedTiles++;
        } while(reducedTiles < num_k_block_tile_iteration);

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            BlockwiseSumReduce::Reduce(reduce_work_buf, accu_value_buf(I));
            // block_sync_lds();
        });

        threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);

        ///
        /// softmax
        ///
        reducedTiles = 0;
        if(float_equal_zero{}(beta))
        {
            do
            {
                threadwise_src_load.Run(in_grid_desc_m_k,
                                        in_global_val_buf_oob_nan,
                                        thread_buffer_desc,
                                        make_tuple(I0, I0),
                                        in_thread_buf);

                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                    // out = alpha * exp(x - max(x)) / sum(exp(x - max(x)))
                    static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                        constexpr auto offset =
                            thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
                        out_thread_buf(Number<offset>{}) =
                            alpha * math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM)) /
                            accu_value_buf(iM);
                    });
                });

                threadwise_dst_store.Run(thread_buffer_desc,
                                         make_tuple(I0, I0),
                                         out_thread_buf,
                                         out_grid_desc_m_k,
                                         out_global_val_buf);

                threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);
                threadwise_dst_store.MoveDstSliceWindow(out_grid_desc_m_k, in_thread_copy_fwd_step);

                reducedTiles++;
            } while(reducedTiles < num_k_block_tile_iteration);
        }
        else
        {
            do
            {
                threadwise_src_load.Run(in_grid_desc_m_k,
                                        in_global_val_buf_oob_nan,
                                        thread_buffer_desc,
                                        make_tuple(I0, I0),
                                        in_thread_buf);
                threadwise_dst_load.Run(out_grid_desc_m_k,
                                        out_global_val_buf,
                                        thread_buffer_desc,
                                        make_tuple(I0, I0),
                                        out_thread_buf);
                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                    // out = alpha * exp(x - max(x)) / sum(exp(x - max(x))) + beta * prior_out
                    static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                        constexpr auto offset =
                            thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
                        out_thread_buf(Number<offset>{}) =
                            alpha * math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM)) /
                                accu_value_buf(iM) +
                            beta * out_thread_buf(Number<offset>{});
                    });
                });

                threadwise_dst_store.Run(thread_buffer_desc,
                                         make_tuple(I0, I0),
                                         out_thread_buf,
                                         out_grid_desc_m_k,
                                         out_global_val_buf);

                threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);
                threadwise_dst_store.MoveDstSliceWindow(out_grid_desc_m_k, in_thread_copy_fwd_step);
                threadwise_dst_load.MoveSrcSliceWindow(out_grid_desc_m_k, in_thread_copy_fwd_step);

                reducedTiles++;
            } while(reducedTiles < num_k_block_tile_iteration);
        }
    }
};

} // namespace ck
#endif // GRIDWISE_SOFTMAX_HPP

@@ -39,7 +39,9 @@ template <typename AccDataType,
          typename SrcThreadDesc_M_K,
          typename DstThreadDesc_M,
          typename OpReduce,
          bool PropagateNan>
          bool PropagateNan,
          typename Accumulation =
              detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>>
struct ThreadwiseReduction
{
    static constexpr auto src_thread_desc_m_k = SrcThreadDesc_M_K{};
@@ -51,8 +53,6 @@ struct ThreadwiseReduction

    static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!");

    using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>;

    template <typename SrcBufferType, typename DstBufferType>
    __device__ static void Reduce(const SrcBufferType& src_buf, DstBufferType& dst_buf)
    {
@@ -73,12 +73,15 @@ struct ThreadwiseReduction
// 2) DstDesc is known at compile-time
// 3) SrcBuffer is static buffer
// 4) DstBuffer is static buffer
template <typename AccDataType,
          typename IndexDataType,
          typename SrcThreadDesc_M_K,
          typename DstThreadDesc_M,
          typename OpReduce,
          bool PropagateNan>
template <
    typename AccDataType,
    typename IndexDataType,
    typename SrcThreadDesc_M_K,
    typename DstThreadDesc_M,
    typename OpReduce,
    bool PropagateNan,
    typename Accumulation =
        detail::AccumulateWithIndexAndNanCheck<PropagateNan, OpReduce, AccDataType, IndexDataType>>
struct ThreadwiseReductionWithIndex
{
    static constexpr auto src_thread_desc_m_k = SrcThreadDesc_M_K{};
@@ -90,9 +93,6 @@ struct ThreadwiseReductionWithIndex

    static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!");

    using Accumulation =
        detail::AccumulateWithIndexAndNanCheck<PropagateNan, OpReduce, AccDataType, IndexDataType>;

    template <typename SrcValueBufferType,
              typename SrcIndexBufferType,
              typename DstValueBufferType,

@@ -1001,6 +1001,11 @@ struct NumericLimits
    __host__ __device__ static constexpr T Max() { return std::numeric_limits<T>::max(); }

    __host__ __device__ static constexpr T Lowest() { return std::numeric_limits<T>::lowest(); }

    __host__ __device__ static constexpr T QuietNaN()
    {
        return std::numeric_limits<T>::quiet_NaN();
    }
};

template <>
@@ -1009,12 +1014,15 @@ struct NumericLimits<half_t>
    static constexpr unsigned short binary_min    = 0x0400;
    static constexpr unsigned short binary_max    = 0x7BFF;
    static constexpr unsigned short binary_lowest = 0xFBFF;
    static constexpr unsigned short binary_qnan   = 0x7FFF;
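    // (IEEE-754 binary16 encodings: 0x0400 is the smallest positive normal value,
    // 0x7BFF the largest finite value (65504), 0xFBFF its negative counterpart,
    // and 0x7FFF a quiet NaN: all-ones exponent with a non-zero mantissa)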

    __host__ __device__ static constexpr half_t Min() { return bit_cast<half_t>(binary_min); }

    __host__ __device__ static constexpr half_t Max() { return bit_cast<half_t>(binary_max); }

    __host__ __device__ static constexpr half_t Lowest() { return bit_cast<half_t>(binary_lowest); }

    __host__ __device__ static constexpr half_t QuietNaN() { return bit_cast<half_t>(binary_qnan); }
};

} // namespace ck

@@ -142,6 +142,22 @@ __host__ __device__ constexpr auto min(X x, Ys... ys)
    return min(x, min(ys...));
}

// disallow implicit type casting
template <typename T>
__device__ T exp(T x);

template <>
__device__ float exp<float>(float x)
{
    return __expf(x);
}

template <>
__device__ double exp<double>(double x)
{
    // qualified call to the builtin; an unqualified exp(x) here would recurse
    // into this very specialization
    return ::exp(x);
}
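// (leaving the primary template undefined means exp() on any type without an
// explicit specialization, e.g. half_t, fails at compile/link time instead of
// silently converting the argument to float or double)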

// greatest common divisor, aka highest common factor
__host__ __device__ constexpr index_t gcd(index_t x, index_t y)
{

@@ -35,9 +35,27 @@
namespace ck {
namespace detail {

// Check for NaN; guarantee NaNs are NOT propagated to result (i.e., ignore NaNs)
template <typename ReduceOperation, typename AccDataType>
struct AccumulateWithNanIgnore
{
    __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
    {
        if(!isnan(currVal))
        {
            ReduceOperation{}(accuVal, currVal);
        }
    };
};
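// e.g. with ReduceOperation = reduce::Add, Calculate(acc, NaN) leaves acc
// untouched; the softmax sum-of-exp pass relies on this to drop elements that
// were initialized to NaN by the out-of-bounds load path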

template <bool PropagateNan, typename ReduceOperation, typename AccDataType>
struct AccumulateWithNanCheck;

// Does not check for NaN; does not guarantee that NaNs are propagated to the result.
// e.g., given that max(a, b) = a > b ? a : b,
// max(NaN, 1) returns 1 but
// max(1, NaN) returns NaN,
// since any comparison involving a NaN returns false
template <typename ReduceOperation, typename AccDataType>
struct AccumulateWithNanCheck<false, ReduceOperation, AccDataType>
{
@@ -48,6 +66,7 @@ struct AccumulateWithNanCheck<false, ReduceOperation, AccDataType>
    };
};

// Check for NaN; guarantees that NaNs are propagated to the result
template <typename ReduceOperation, typename AccDataType>
struct AccumulateWithNanCheck<true, ReduceOperation, AccDataType>
{

@@ -107,6 +107,11 @@ struct HostTensorDescriptor
        return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
    }

    std::size_t GetOffsetFromMultiIndex(std::vector<std::size_t> iss) const
    {
        return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
    }

    friend std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc);

    private:
@@ -212,6 +217,54 @@ struct Tensor

    Tensor(const HostTensorDescriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpace()) {}

    Tensor(const Tensor& other) : mDesc(other.mDesc), mData(other.mData) {}

    template <typename F>
    void ForEach_impl(F&& f, std::vector<size_t>& idx, size_t rank)
    {
        if(rank == mDesc.GetNumOfDimension())
        {
            f(*this, idx);
            return;
        }
        // else
        for(size_t i = 0; i < mDesc.GetLengths()[rank]; i++)
        {
            idx[rank] = i;
            ForEach_impl(std::forward<F>(f), idx, rank + 1);
        }
    }

    template <typename F>
    void ForEach(F&& f)
    {
        std::vector<size_t> idx(mDesc.GetNumOfDimension(), 0);
        ForEach_impl(std::forward<F>(f), idx, size_t(0));
    }

    template <typename F>
    void ForEach_impl(const F&& f, std::vector<size_t>& idx, size_t rank) const
    {
        if(rank == mDesc.GetNumOfDimension())
        {
            f(*this, idx);
            return;
        }
        // else
        for(size_t i = 0; i < mDesc.GetLengths()[rank]; i++)
        {
            idx[rank] = i;
            ForEach_impl(std::forward<const F>(f), idx, rank + 1);
        }
    }

    template <typename F>
    void ForEach(const F&& f) const
    {
        std::vector<size_t> idx(mDesc.GetNumOfDimension(), 0);
        ForEach_impl(std::forward<const F>(f), idx, size_t(0));
    }

    template <typename G>
    void GenerateTensorValue(G g, std::size_t num_thread = 1)
    {
@@ -272,6 +325,16 @@ struct Tensor
        return mData[mDesc.GetOffsetFromMultiIndex(is...)];
    }

    T& operator()(std::vector<std::size_t> idx)
    {
        return mData[mDesc.GetOffsetFromMultiIndex(idx)];
    }

    const T& operator()(std::vector<std::size_t> idx) const
    {
        return mData[mDesc.GetOffsetFromMultiIndex(idx)];
    }

    typename std::vector<T>::iterator begin() { return mData.begin(); }

    typename std::vector<T>::iterator end() { return mData.end(); }
@@ -285,7 +348,8 @@ struct Tensor
};

template <typename X>
HostTensorDescriptor::HostTensorDescriptor(const std::vector<X>& lens) : mLens(lens)
HostTensorDescriptor::HostTensorDescriptor(const std::vector<X>& lens)
    : mLens(lens.begin(), lens.end())
{
    this->CalculateStrides();
}
@@ -293,7 +357,7 @@ HostTensorDescriptor::HostTensorDescriptor(const std::vector<X>& lens) : mLens(l
template <typename X, typename Y>
HostTensorDescriptor::HostTensorDescriptor(const std::vector<X>& lens,
                                           const std::vector<Y>& strides)
    : mLens(lens), mStrides(strides)
    : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
{
}

@@ -18,12 +18,12 @@ struct GeneratorTensor_0
template <typename T>
struct GeneratorTensor_1
{
    int value = 1;
    T value = 1;

    template <typename... Is>
    T operator()(Is...)
    {
        return ck::type_convert<T>(value);
        return value;
    }
};

@@ -0,0 +1,162 @@
#pragma once
#include <iostream>
#include <sstream>
#include <vector>
#include <algorithm>
#include "device_base.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"

namespace ck {
namespace tensor_operation {
namespace host {

template <typename InDataType, typename OutDataType, typename AccDataType>
struct ReferenceSoftmax : public device::BaseOperator
{
    // Argument
    struct Argument : public device::BaseArgument
    {
        Argument(const Tensor<InDataType>& in,
                 Tensor<OutDataType>& out,
                 AccDataType alpha,
                 AccDataType beta,
                 const index_t rank,
                 const std::vector<index_t> sm_reduce_dims)
            : in_(in), out_(out), alpha_(alpha), beta_(beta), sm_reduce_dims_(sm_reduce_dims)
        {
            // std::cout << "debug: scalar dims: ";
            for(int i = 0; i < rank; i++)
            {
                if(std::find(sm_reduce_dims.begin(), sm_reduce_dims.end(), i) ==
                   sm_reduce_dims.end())
                {
                    sm_scalar_dims_.push_back(i);
                    // std::cout << i << ", ";
                }
            }
            // std::cout << std::endl;
        }

        const Tensor<InDataType>& in_;
        Tensor<OutDataType>& out_;
        AccDataType alpha_;
        AccDataType beta_;
        index_t rank_;
        std::vector<index_t> sm_reduce_dims_;
        std::vector<index_t> sm_scalar_dims_; // dims remaining after internal max/sum reduction
    };

    // Invoker
    struct Invoker : public device::BaseInvoker
    {
        float Run(const Argument& arg)
        {
            std::vector<size_t> scalar_lengths;
            for(index_t dim : arg.sm_scalar_dims_)
            {
                scalar_lengths.push_back(arg.in_.mDesc.GetLengths()[dim]);
            }

            Tensor<AccDataType> reduce_max(scalar_lengths);
            reduce_max.GenerateTensorValue(
                GeneratorTensor_1<AccDataType>{std::numeric_limits<AccDataType>::lowest()});
            Tensor<AccDataType> reduce_sum(scalar_lengths);
            reduce_sum.GenerateTensorValue(GeneratorTensor_1<AccDataType>{0});

            auto to_sm_scalar_idx = [&](auto idx) {
                std::vector<size_t> sm_scalar_idx;
                for(index_t dim : arg.sm_scalar_dims_)
                {
                    sm_scalar_idx.push_back(idx[dim]);
                }
                return sm_scalar_idx;
            };

            arg.in_.ForEach([&](auto& self, auto idx) {
                reduce_max(to_sm_scalar_idx(idx)) = std::max(reduce_max(to_sm_scalar_idx(idx)),
                                                             static_cast<AccDataType>(self(idx)));
            });

            // LogRangeAsType<float>(std::cout << "reduce_max: ", reduce_max.mData, ",") <<
            // std::endl;

            Tensor<AccDataType> in_stable(arg.in_.mDesc);
            in_stable.ForEach([&](auto& self, auto idx) {
                // numerator = exp(x - max(x))
                self(idx) = std::exp(static_cast<AccDataType>(arg.in_(idx)) -
                                     reduce_max(to_sm_scalar_idx(idx)));
            });

            // LogRangeAsType<float>(std::cout << "in_stable: ", in_stable.mData, ",") << std::endl;

            in_stable.ForEach([&](auto& self, auto idx) {
                // denominator = sum(exp(x - max(x)))
                reduce_sum(to_sm_scalar_idx(idx)) += self(idx);
            });

            // LogRangeAsType<float>(std::cout << "reduce_sum: ", reduce_sum.mData, ",") <<
            // std::endl;

            arg.out_.ForEach([&](auto& self, auto idx) {
                self(idx) = arg.alpha_ * in_stable(idx) / reduce_sum(to_sm_scalar_idx(idx)) +
                            arg.beta_ * self(idx);
            });

            // LogRangeAsType<float>(std::cout << "out: ", arg.out_.mData, ",") << std::endl;
            // reduction along reduce dims
            // LogRangeAsType<float>(std::cout << "reduce_max: ", reduce_max.mData, ",") <<
            // std::endl; LogRangeAsType<float>(std::cout << "reduce_sum: ", reduce_sum.mData, ",")
            // << std::endl;

            return 0;
        }

        float Run(const device::BaseArgument* p_arg,
                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    bool IsSupportedArgument(const device::BaseArgument*) override { return true; }

    static auto MakeArgument(const Tensor<InDataType>& in,
                             Tensor<OutDataType>& out,
                             AccDataType alpha,
                             AccDataType beta,
                             const index_t rank,
                             const std::vector<index_t> sm_reduce_dims)
    {
        return Argument{in, out, alpha, beta, rank, sm_reduce_dims};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "ReferenceSoftmax"
            << std::endl;
        // clang-format on

        return str.str();
    }
};

} // namespace host
} // namespace tensor_operation
} // namespace ck

@@ -65,4 +65,5 @@ add_subdirectory(reduce)
add_subdirectory(conv2d_bwd_weight)
add_subdirectory(convnd_bwd_data)
add_subdirectory(block_to_ctile_map)
add_subdirectory(softmax)
# DO NOT add client_app; that is tested via CI independently

8 test/softmax/CMakeLists.txt Normal file
@@ -0,0 +1,8 @@
add_custom_target(test_softmax)

add_gtest_executable(test_softmax_fp32 test_softmax_fp32.cpp)
add_gtest_executable(test_softmax_fp16 test_softmax_fp16.cpp)
target_link_libraries(test_softmax_fp32 PRIVATE host_tensor)
target_link_libraries(test_softmax_fp16 PRIVATE host_tensor)
add_dependencies(test_softmax test_softmax_fp32)
add_dependencies(test_softmax test_softmax_fp16)
26 test/softmax/test_softmax_fp16.cpp Normal file
@@ -0,0 +1,26 @@
#include "gtest/gtest.h"
#include "test_softmax_util.hpp"

template <ck::index_t N>
using I = ck::Number<N>;

template <typename Tuple>
class TestSoftmaxFP16 : public ck::TestSoftmax<Tuple>
{
};

// clang-format off
using KernelTypes = ::testing::Types<
    // InDataType, AccDataType, OutDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize
    std::tuple<ck::half_t, float, ck::half_t, I<3>, I<1>, I<256>, I<8>, I<32>, I<1>, I<8>, I<1>, I<8>, I<8>>,
    std::tuple<ck::half_t, float, ck::half_t, I<3>, I<1>, I<256>, I<4>, I<64>, I<1>, I<8>, I<1>, I<8>, I<8>>,
    std::tuple<ck::half_t, float, ck::half_t, I<3>, I<1>, I<256>, I<2>, I<128>, I<1>, I<8>, I<1>, I<8>, I<8>>,
    std::tuple<ck::half_t, float, ck::half_t, I<3>, I<1>, I<256>, I<1>, I<256>, I<1>, I<8>, I<1>, I<8>, I<8>>,
    std::tuple<ck::half_t, float, ck::half_t, I<3>, I<2>, I<256>, I<8>, I<32>, I<1>, I<8>, I<1>, I<8>, I<8>>,
    std::tuple<ck::half_t, float, ck::half_t, I<3>, I<2>, I<256>, I<4>, I<64>, I<1>, I<8>, I<1>, I<8>, I<8>>,
    std::tuple<ck::half_t, float, ck::half_t, I<3>, I<2>, I<256>, I<2>, I<128>, I<1>, I<8>, I<1>, I<8>, I<8>>,
    std::tuple<ck::half_t, float, ck::half_t, I<3>, I<2>, I<256>, I<1>, I<256>, I<1>, I<8>, I<1>, I<8>, I<8>>
    >;
// clang-format on
TYPED_TEST_SUITE(TestSoftmaxFP16, KernelTypes);
TYPED_TEST(TestSoftmaxFP16, Test_FP16) { this->Run(); }
26 test/softmax/test_softmax_fp32.cpp Normal file
@@ -0,0 +1,26 @@
#include "gtest/gtest.h"
#include "test_softmax_util.hpp"

template <ck::index_t N>
using I = ck::Number<N>;

template <typename Tuple>
class TestSoftmaxFP32 : public ck::TestSoftmax<Tuple>
{
};

// clang-format off
using KernelTypes = ::testing::Types<
    // InDataType, AccDataType, OutDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize
    std::tuple<float, float, float, I<3>, I<1>, I<256>, I<8>, I<32>, I<1>, I<4>, I<1>, I<4>, I<4>>,
    std::tuple<float, float, float, I<3>, I<1>, I<256>, I<4>, I<64>, I<1>, I<4>, I<1>, I<4>, I<4>>,
    std::tuple<float, float, float, I<3>, I<1>, I<256>, I<2>, I<128>, I<1>, I<4>, I<1>, I<4>, I<4>>,
    std::tuple<float, float, float, I<3>, I<1>, I<256>, I<1>, I<256>, I<1>, I<4>, I<1>, I<4>, I<4>>,
    std::tuple<float, float, float, I<3>, I<2>, I<256>, I<8>, I<32>, I<1>, I<4>, I<1>, I<4>, I<4>>,
    std::tuple<float, float, float, I<3>, I<2>, I<256>, I<4>, I<64>, I<1>, I<4>, I<1>, I<4>, I<4>>,
    std::tuple<float, float, float, I<3>, I<2>, I<256>, I<2>, I<128>, I<1>, I<4>, I<1>, I<4>, I<4>>,
    std::tuple<float, float, float, I<3>, I<2>, I<256>, I<1>, I<256>, I<1>, I<4>, I<1>, I<4>, I<4>>
    >;
// clang-format on
TYPED_TEST_SUITE(TestSoftmaxFP32, KernelTypes);
TYPED_TEST(TestSoftmaxFP32, Test_FP32) { this->Run(); }
113 test/softmax/test_softmax_util.hpp Normal file
@@ -0,0 +1,113 @@
#include <vector>
#include <iostream>
#include <numeric> // std::iota
#include "gtest/gtest.h"

#include "config.hpp"
#include "host_tensor.hpp"
#include "check_err.hpp"
#include "number.hpp"
#include "reference_softmax.hpp"
#include "device_softmax.hpp"

namespace ck {

template <typename Tuple>
class TestSoftmax : public ::testing::Test
{
    protected:
    using InDataType  = std::tuple_element_t<0, Tuple>;
    using AccDataType = std::tuple_element_t<1, Tuple>;
    using OutDataType = std::tuple_element_t<2, Tuple>;
    static constexpr index_t Rank               = std::tuple_element_t<3, Tuple>{}.value;
    static constexpr index_t NumReduceDim       = std::tuple_element_t<4, Tuple>{}.value;
    static constexpr index_t BlockSize          = std::tuple_element_t<5, Tuple>{}.value;
    static constexpr index_t MThreadClusterSize = std::tuple_element_t<6, Tuple>{}.value;
    static constexpr index_t KThreadClusterSize = std::tuple_element_t<7, Tuple>{}.value;
    static constexpr index_t MThreadSliceSize   = std::tuple_element_t<8, Tuple>{}.value;
    static constexpr index_t KThreadSliceSize   = std::tuple_element_t<9, Tuple>{}.value;
    static constexpr index_t InSrcVectorDim     = std::tuple_element_t<10, Tuple>{}.value;
    static constexpr index_t InSrcVectorSize    = std::tuple_element_t<11, Tuple>{}.value;
    static constexpr index_t OutDstVectorSize   = std::tuple_element_t<12, Tuple>{}.value;

    using ReferenceInstance =
        tensor_operation::host::ReferenceSoftmax<InDataType, OutDataType, AccDataType>;

    using DeviceInstance = tensor_operation::device::DeviceSoftmax<InDataType,
                                                                   AccDataType,
                                                                   OutDataType,
                                                                   Rank,
                                                                   NumReduceDim,
                                                                   BlockSize,
                                                                   MThreadClusterSize,
                                                                   KThreadClusterSize,
                                                                   MThreadSliceSize,
                                                                   KThreadSliceSize,
                                                                   InSrcVectorDim,
                                                                   InSrcVectorSize,
                                                                   OutDstVectorSize>;

    TestSoftmax() : ref_instance_invoker_(ReferenceInstance{}.MakeInvoker()) {}

    void RunSingle(std::vector<index_t> in_length, AccDataType alpha, AccDataType beta)
    {
        std::vector<index_t> reduce_dims(NumReduceDim);
        std::iota(reduce_dims.begin(), reduce_dims.end(), Rank - NumReduceDim);
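        // reduce over the innermost NumReduceDim dimensions,
        // e.g. Rank = 3, NumReduceDim = 1 gives reduce_dims = {2}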

        Tensor<InDataType> in(in_length);
        Tensor<OutDataType> out(in_length);

        in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
        out.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});

        Tensor<OutDataType> out_ref(out);

        DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpace());
        DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpace());
        in_dev.ToDevice(in.mData.data());
        out_dev.ToDevice(out.mData.data());

        std::vector<index_t> i_in_lengths(in.mDesc.GetLengths().begin(),
                                          in.mDesc.GetLengths().end());
        std::vector<index_t> i_in_strides(in.mDesc.GetStrides().begin(),
                                          in.mDesc.GetStrides().end());

        auto device_instance = DeviceInstance{};
        auto argument_ptr    = device_instance.MakeArgumentPointer(i_in_lengths,
                                                                   i_in_strides,
                                                                   reduce_dims,
                                                                   alpha,
                                                                   beta,
                                                                   in_dev.GetDeviceBuffer(),
                                                                   out_dev.GetDeviceBuffer());

        if(!device_instance.IsSupportedArgument(argument_ptr.get()))
        {
            FAIL() << "Unsupported argument";
        }

        auto invoker_ptr = device_instance.MakeInvokerPointer();
        invoker_ptr->Run(argument_ptr.get());

        ref_instance_invoker_.Run({in, out_ref, alpha, beta, Rank, reduce_dims});

        out_dev.FromDevice(out.mData.data());
        EXPECT_TRUE(ck::utils::check_err(out.mData, out_ref.mData));
    }

    void Run()
    {
        for(auto in_length : this->in_lengths_)
        {
            for(auto scale : this->scales_)
            {
                this->RunSingle(in_length, std::get<0>(scale), std::get<1>(scale));
            }
        }
    }

    std::vector<std::vector<index_t>> in_lengths_ = {{1, 8, 128}, {2, 128, 1024}, {3, 9, 1032}};
    std::vector<std::tuple<AccDataType, AccDataType>> scales_ = {{1, 0}, {2, 2}, {0, 1}};

    typename ReferenceInstance::Invoker ref_instance_invoker_;
};
} // namespace ck