mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Use double for all scaling values and floating-point constant values at the Device Op API (#557)
* Use double as alpha/beta values type in reduce device op api
* Use double as alpha/beta values type in softmax device op api
* Use double as alpha/beta values type in multiple-reduce device op api
* Use double as epsilon value type in normalization/elementwise-normalization device op api
[ROCm/composable_kernel commit: 52abc2f371]
This commit is contained in:
@@ -267,8 +267,8 @@ int reduce_blockwise_impl(bool do_verification,
|
||||
arrOutLengths,
|
||||
arrOutStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
static_cast<double>(alpha),
|
||||
static_cast<double>(beta),
|
||||
in.mData.data(),
|
||||
nullptr,
|
||||
out_ref.mData.data(),
|
||||
@@ -295,8 +295,8 @@ int reduce_blockwise_impl(bool do_verification,
|
||||
arrOutLengths,
|
||||
arrOutStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
static_cast<double>(alpha),
|
||||
static_cast<double>(beta),
|
||||
in_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
out_dev.GetDeviceBuffer(),
|
||||
|
||||
@@ -226,8 +226,8 @@ int main(int argc, char* argv[])
|
||||
arrOutLengths,
|
||||
arrOutStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
static_cast<double>(alpha),
|
||||
static_cast<double>(beta),
|
||||
in_1.mData.data(),
|
||||
nullptr,
|
||||
out_ref.mData.data(),
|
||||
@@ -254,8 +254,8 @@ int main(int argc, char* argv[])
|
||||
arrInLengths_2,
|
||||
arrInStrides_2,
|
||||
reduceDims_1,
|
||||
1.0f,
|
||||
0.0f,
|
||||
1.0,
|
||||
0.0,
|
||||
in_1_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
in_2_dev.GetDeviceBuffer(),
|
||||
@@ -278,8 +278,8 @@ int main(int argc, char* argv[])
|
||||
arrOutLengths,
|
||||
arrOutStrides,
|
||||
reduceDims_2,
|
||||
alpha,
|
||||
beta,
|
||||
static_cast<double>(alpha),
|
||||
static_cast<double>(beta),
|
||||
in_2_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
out_dev.GetDeviceBuffer(),
|
||||
|
||||
@@ -180,8 +180,8 @@ int reduce_multiblock_atomic_add_impl(bool do_verification,
|
||||
arrOutLengths,
|
||||
arrOutStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
static_cast<double>(alpha),
|
||||
static_cast<double>(beta),
|
||||
in.mData.data(),
|
||||
nullptr,
|
||||
out_ref.mData.data(),
|
||||
@@ -208,8 +208,8 @@ int reduce_multiblock_atomic_add_impl(bool do_verification,
|
||||
arrOutLengths,
|
||||
arrOutStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
static_cast<double>(alpha),
|
||||
static_cast<double>(beta),
|
||||
in_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
out_dev.GetDeviceBuffer(),
|
||||
|
||||
@@ -56,8 +56,8 @@ class SimpleAppArgs
|
||||
int option_index = 0;
|
||||
|
||||
public:
|
||||
std::vector<size_t> inLengths = {8, 128, 2048};
|
||||
std::vector<AccDataType> scales = {2.0f, 2.0f};
|
||||
std::vector<size_t> inLengths = {8, 128, 2048};
|
||||
std::vector<double> scales = {2.0, 2.0};
|
||||
|
||||
bool do_verification = true;
|
||||
int init_method = 2;
|
||||
@@ -151,8 +151,8 @@ int main(int argc, char* argv[])
|
||||
auto inStrides = in.mDesc.GetStrides();
|
||||
auto outStrides = out.mDesc.GetStrides();
|
||||
|
||||
AccDataType alpha = args.scales[0];
|
||||
AccDataType beta = args.scales[1];
|
||||
double alpha = args.scales[0];
|
||||
double beta = args.scales[1];
|
||||
|
||||
std::cout << "in: " << in.mDesc << std::endl;
|
||||
std::cout << "out: " << out.mDesc << std::endl;
|
||||
@@ -221,8 +221,8 @@ int main(int argc, char* argv[])
|
||||
auto argument_ptr = device_instance.MakeArgumentPointer(i_inLengths,
|
||||
i_inStrides,
|
||||
reduceDims,
|
||||
&alpha,
|
||||
&beta,
|
||||
alpha,
|
||||
beta,
|
||||
in_dev.GetDeviceBuffer(),
|
||||
out_dev.GetDeviceBuffer(),
|
||||
PassThrough{},
|
||||
|
||||
@@ -217,8 +217,8 @@ int mean_meansquare_dual_reduce_test(size_t n,
|
||||
size_t invariant_total_length = n;
|
||||
size_t reduce_total_length = h * w * c;
|
||||
|
||||
const AccDataType alpha = ck::type_convert<AccDataType>(1.0f);
|
||||
const AccDataType beta = ck::type_convert<AccDataType>(0.0f);
|
||||
const double alpha = 1.0f;
|
||||
const double beta = 0.0f;
|
||||
|
||||
std::size_t num_thread = 1;
|
||||
|
||||
@@ -267,8 +267,8 @@ int mean_meansquare_dual_reduce_test(size_t n,
|
||||
i_outLengths,
|
||||
{i_outStrides, i_outStrides},
|
||||
reduceDims,
|
||||
{&alpha, &alpha},
|
||||
{&beta, &beta},
|
||||
{alpha, alpha},
|
||||
{beta, beta},
|
||||
in_dev.GetDeviceBuffer(),
|
||||
{mean_dev.GetDeviceBuffer(), meansquare_dev.GetDeviceBuffer()},
|
||||
ck::make_tuple(InElementwiseOperation_Mean{}, InElementwiseOperation_Meansquare{}),
|
||||
|
||||
Reference in New Issue
Block a user