mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Use double for all scaling values and float-point constant values at the Device Op API (#557)
* Use double as alpha/beta values type in reduce device op api * Use double as alpha/beta values type in softmax device op api * Use double as alpha/beta values type in multiple-reduce device op api * Use double as epsilon value type in normalization/elementwise-normalization device op api
This commit is contained in:
@@ -47,8 +47,8 @@ int main(int argc, char* argv[])
|
||||
ck::index_t num_elements =
|
||||
std::accumulate(in_lengths.begin(), in_lengths.end(), 1, std::multiplies<ck::index_t>());
|
||||
|
||||
AccDataType alpha{2.0f};
|
||||
AccDataType beta{2.0f};
|
||||
double alpha{2.0};
|
||||
double beta{2.0};
|
||||
|
||||
SimpleDeviceMem in(sizeof(InDataType) * num_elements);
|
||||
SimpleDeviceMem out(sizeof(OutDataType) * num_elements);
|
||||
@@ -82,8 +82,8 @@ int main(int argc, char* argv[])
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(in_lengths,
|
||||
in_strides,
|
||||
reduce_dims,
|
||||
&alpha,
|
||||
&beta,
|
||||
alpha,
|
||||
beta,
|
||||
in.GetDeviceBuffer(),
|
||||
out.GetDeviceBuffer(),
|
||||
PassThrough{},
|
||||
@@ -129,8 +129,8 @@ int main(int argc, char* argv[])
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(in_lengths,
|
||||
in_strides,
|
||||
reduce_dims,
|
||||
&alpha,
|
||||
&beta,
|
||||
alpha,
|
||||
beta,
|
||||
in.GetDeviceBuffer(),
|
||||
out.GetDeviceBuffer(),
|
||||
PassThrough{},
|
||||
@@ -147,4 +147,4 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user