mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Use double for all scaling values and floating-point constant values in the Device Op API (#557)
* Use double as alpha/beta values type in reduce device op api
* Use double as alpha/beta values type in softmax device op api
* Use double as alpha/beta values type in multiple-reduce device op api
* Use double as epsilon value type in normalization/elementwise-normalization device op api
This commit is contained in:
@@ -267,8 +267,8 @@ int reduce_blockwise_impl(bool do_verification,
|
||||
arrOutLengths,
|
||||
arrOutStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
static_cast<double>(alpha),
|
||||
static_cast<double>(beta),
|
||||
in.mData.data(),
|
||||
nullptr,
|
||||
out_ref.mData.data(),
|
||||
@@ -295,8 +295,8 @@ int reduce_blockwise_impl(bool do_verification,
|
||||
arrOutLengths,
|
||||
arrOutStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
static_cast<double>(alpha),
|
||||
static_cast<double>(beta),
|
||||
in_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
out_dev.GetDeviceBuffer(),
|
||||
|
||||
@@ -226,8 +226,8 @@ int main(int argc, char* argv[])
|
||||
arrOutLengths,
|
||||
arrOutStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
static_cast<double>(alpha),
|
||||
static_cast<double>(beta),
|
||||
in_1.mData.data(),
|
||||
nullptr,
|
||||
out_ref.mData.data(),
|
||||
@@ -254,8 +254,8 @@ int main(int argc, char* argv[])
|
||||
arrInLengths_2,
|
||||
arrInStrides_2,
|
||||
reduceDims_1,
|
||||
1.0f,
|
||||
0.0f,
|
||||
1.0,
|
||||
0.0,
|
||||
in_1_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
in_2_dev.GetDeviceBuffer(),
|
||||
@@ -278,8 +278,8 @@ int main(int argc, char* argv[])
|
||||
arrOutLengths,
|
||||
arrOutStrides,
|
||||
reduceDims_2,
|
||||
alpha,
|
||||
beta,
|
||||
static_cast<double>(alpha),
|
||||
static_cast<double>(beta),
|
||||
in_2_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
out_dev.GetDeviceBuffer(),
|
||||
|
||||
@@ -180,8 +180,8 @@ int reduce_multiblock_atomic_add_impl(bool do_verification,
|
||||
arrOutLengths,
|
||||
arrOutStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
static_cast<double>(alpha),
|
||||
static_cast<double>(beta),
|
||||
in.mData.data(),
|
||||
nullptr,
|
||||
out_ref.mData.data(),
|
||||
@@ -208,8 +208,8 @@ int reduce_multiblock_atomic_add_impl(bool do_verification,
|
||||
arrOutLengths,
|
||||
arrOutStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
static_cast<double>(alpha),
|
||||
static_cast<double>(beta),
|
||||
in_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
out_dev.GetDeviceBuffer(),
|
||||
|
||||
Reference in New Issue
Block a user