mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
Use double for all scaling values and float-point constant values at the Device Op API (#557)
* Use double as alpha/beta values type in reduce device op api * Use double as alpha/beta values type in softmax device op api * Use double as alpha/beta values type in multiple-reduce device op api * Use double as epsilon value type in normalization/elementwise-normalization device op api
This commit is contained in:
@@ -56,8 +56,8 @@ struct ReferenceReduce : public device::DeviceReduce<InDataType,
|
||||
const std::array<index_t, NumDstDim> outLengths,
|
||||
const std::array<index_t, NumDstDim> outStrides,
|
||||
const std::array<int, NumReduceDim> reduceDims,
|
||||
float alpha,
|
||||
float beta,
|
||||
double alpha,
|
||||
double beta,
|
||||
const InDataType* in_host,
|
||||
OutDataType* out_host,
|
||||
IndexDataType* out_index_host,
|
||||
@@ -388,8 +388,8 @@ struct ReferenceReduce : public device::DeviceReduce<InDataType,
|
||||
const std::array<index_t, NumDstDim> outLengths,
|
||||
const std::array<index_t, NumDstDim> outStrides,
|
||||
const std::array<int, NumReduceDim> reduceDims,
|
||||
float alpha,
|
||||
float beta,
|
||||
double alpha,
|
||||
double beta,
|
||||
const void* in_host,
|
||||
const void* in_index_host,
|
||||
void* out_host,
|
||||
|
||||
@@ -24,11 +24,14 @@ struct ReferenceSoftmax : public device::BaseOperator
|
||||
{
|
||||
Argument(const Tensor<InDataType>& in,
|
||||
Tensor<OutDataType>& out,
|
||||
AccDataType alpha,
|
||||
AccDataType beta,
|
||||
double alpha,
|
||||
double beta,
|
||||
const std::vector<index_t> sm_reduce_dims)
|
||||
: in_(in), out_(out), alpha_(alpha), beta_(beta), sm_reduce_dims_(sm_reduce_dims)
|
||||
: in_(in), out_(out), sm_reduce_dims_(sm_reduce_dims)
|
||||
{
|
||||
alpha_ = static_cast<AccDataType>(alpha);
|
||||
beta_ = static_cast<AccDataType>(beta);
|
||||
|
||||
// std::cout << "debug: scalar dims: ";
|
||||
for(size_t i = 0; i < in.mDesc.GetNumOfDimension(); i++)
|
||||
{
|
||||
@@ -143,8 +146,8 @@ struct ReferenceSoftmax : public device::BaseOperator
|
||||
|
||||
static auto MakeArgument(const Tensor<InDataType>& in,
|
||||
Tensor<OutDataType>& out,
|
||||
AccDataType alpha,
|
||||
AccDataType beta,
|
||||
double alpha,
|
||||
double beta,
|
||||
const std::vector<index_t> sm_reduce_dims)
|
||||
{
|
||||
return Argument{in, out, alpha, beta, sm_reduce_dims};
|
||||
|
||||
Reference in New Issue
Block a user