Use double for all scaling values and float-point constant values at the Device Op API (#557)

* Use double as alpha/beta values type in reduce device op api

* Use double as alpha/beta values type in softmax device op api

* Use double as alpha/beta values type in multiple-reduce device op api

* Use double as epsilon value type in normalization/elementwise-normalization device op api
This commit is contained in:
Qianfeng
2023-01-19 02:02:50 +08:00
committed by GitHub
parent 1cfa87608a
commit 52abc2f371
24 changed files with 112 additions and 109 deletions

View File

@@ -47,8 +47,8 @@ int main(int argc, char* argv[])
ck::index_t num_elements =
std::accumulate(in_lengths.begin(), in_lengths.end(), 1, std::multiplies<ck::index_t>());
AccDataType alpha{2.0f};
AccDataType beta{2.0f};
double alpha{2.0};
double beta{2.0};
SimpleDeviceMem in(sizeof(InDataType) * num_elements);
SimpleDeviceMem out(sizeof(OutDataType) * num_elements);
@@ -82,8 +82,8 @@ int main(int argc, char* argv[])
auto argument_ptr = op_ptr->MakeArgumentPointer(in_lengths,
in_strides,
reduce_dims,
&alpha,
&beta,
alpha,
beta,
in.GetDeviceBuffer(),
out.GetDeviceBuffer(),
PassThrough{},
@@ -129,8 +129,8 @@ int main(int argc, char* argv[])
auto argument_ptr = op_ptr->MakeArgumentPointer(in_lengths,
in_strides,
reduce_dims,
&alpha,
&beta,
alpha,
beta,
in.GetDeviceBuffer(),
out.GetDeviceBuffer(),
PassThrough{},
@@ -147,4 +147,4 @@ int main(int argc, char* argv[])
}
return 0;
}
}