mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 04:49:54 +00:00
Use double for all scaling values and float-point constant values at the Device Op API (#557)
* Use double as alpha/beta values type in reduce device op api
* Use double as alpha/beta values type in softmax device op api
* Use double as alpha/beta values type in multiple-reduce device op api
* Use double as epsilon value type in normalization/elementwise-normalization device op api
[ROCm/composable_kernel commit: 52abc2f371]
This commit is contained in:
@@ -332,8 +332,8 @@ bool profile_reduce_impl_impl(bool do_verification,
|
||||
arrOutLengths,
|
||||
arrOutStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
static_cast<double>(alpha),
|
||||
static_cast<double>(beta),
|
||||
in.mData.data(),
|
||||
nullptr,
|
||||
out_ref.mData.data(),
|
||||
@@ -361,8 +361,8 @@ bool profile_reduce_impl_impl(bool do_verification,
|
||||
arrOutLengths,
|
||||
arrOutStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
static_cast<double>(alpha),
|
||||
static_cast<double>(beta),
|
||||
in_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
out_dev.GetDeviceBuffer(),
|
||||
|
||||
@@ -48,8 +48,8 @@ bool profile_softmax_impl(int do_verification,
|
||||
std::vector<index_t> in_length,
|
||||
std::vector<index_t> in_strides,
|
||||
std::vector<index_t> reduce_dims,
|
||||
AccDataType alpha,
|
||||
AccDataType beta)
|
||||
double alpha,
|
||||
double beta)
|
||||
{
|
||||
if(Rank != in_length.size())
|
||||
{
|
||||
@@ -122,8 +122,8 @@ bool profile_softmax_impl(int do_verification,
|
||||
auto argument_ptr = inst_ptr->MakeArgumentPointer(in_tensor_lengths,
|
||||
in_tensor_strides,
|
||||
reduce_dims,
|
||||
&alpha,
|
||||
&beta,
|
||||
alpha,
|
||||
beta,
|
||||
in_dev.GetDeviceBuffer(),
|
||||
out_dev.GetDeviceBuffer(),
|
||||
PassThrough{},
|
||||
|
||||
Reference in New Issue
Block a user