mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Use double for all scaling values and float-point constant values at the Device Op API (#557)
* Use double as alpha/beta values type in reduce device op api
* Use double as alpha/beta values type in softmax device op api
* Use double as alpha/beta values type in multiple-reduce device op api
* Use double as epsilon value type in normalization/elementwise-normalization device op api
[ROCm/composable_kernel commit: 52abc2f371]
This commit is contained in:
@@ -47,8 +47,8 @@ int main(int argc, char* argv[])
|
||||
ck::index_t num_elements =
|
||||
std::accumulate(in_lengths.begin(), in_lengths.end(), 1, std::multiplies<ck::index_t>());
|
||||
|
||||
AccDataType alpha{2.0f};
|
||||
AccDataType beta{2.0f};
|
||||
double alpha{2.0};
|
||||
double beta{2.0};
|
||||
|
||||
SimpleDeviceMem in(sizeof(InDataType) * num_elements);
|
||||
SimpleDeviceMem out(sizeof(OutDataType) * num_elements);
|
||||
@@ -82,8 +82,8 @@ int main(int argc, char* argv[])
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(in_lengths,
|
||||
in_strides,
|
||||
reduce_dims,
|
||||
&alpha,
|
||||
&beta,
|
||||
alpha,
|
||||
beta,
|
||||
in.GetDeviceBuffer(),
|
||||
out.GetDeviceBuffer(),
|
||||
PassThrough{},
|
||||
@@ -129,8 +129,8 @@ int main(int argc, char* argv[])
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(in_lengths,
|
||||
in_strides,
|
||||
reduce_dims,
|
||||
&alpha,
|
||||
&beta,
|
||||
alpha,
|
||||
beta,
|
||||
in.GetDeviceBuffer(),
|
||||
out.GetDeviceBuffer(),
|
||||
PassThrough{},
|
||||
@@ -147,4 +147,4 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,8 +61,8 @@ int main(int argc, char* argv[])
|
||||
for(auto dim : reduce_dims)
|
||||
reduce_length *= in_lengths[dim];
|
||||
|
||||
float alpha{1.0f};
|
||||
float beta{0.0f};
|
||||
double alpha{1.0};
|
||||
double beta{0.0};
|
||||
|
||||
SimpleDeviceMem in(sizeof(InDataType) * num_in_elements);
|
||||
SimpleDeviceMem out(sizeof(OutDataType) * num_out_elements);
|
||||
|
||||
@@ -267,8 +267,8 @@ int reduce_blockwise_impl(bool do_verification,
|
||||
arrOutLengths,
|
||||
arrOutStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
static_cast<double>(alpha),
|
||||
static_cast<double>(beta),
|
||||
in.mData.data(),
|
||||
nullptr,
|
||||
out_ref.mData.data(),
|
||||
@@ -295,8 +295,8 @@ int reduce_blockwise_impl(bool do_verification,
|
||||
arrOutLengths,
|
||||
arrOutStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
static_cast<double>(alpha),
|
||||
static_cast<double>(beta),
|
||||
in_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
out_dev.GetDeviceBuffer(),
|
||||
|
||||
@@ -226,8 +226,8 @@ int main(int argc, char* argv[])
|
||||
arrOutLengths,
|
||||
arrOutStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
static_cast<double>(alpha),
|
||||
static_cast<double>(beta),
|
||||
in_1.mData.data(),
|
||||
nullptr,
|
||||
out_ref.mData.data(),
|
||||
@@ -254,8 +254,8 @@ int main(int argc, char* argv[])
|
||||
arrInLengths_2,
|
||||
arrInStrides_2,
|
||||
reduceDims_1,
|
||||
1.0f,
|
||||
0.0f,
|
||||
1.0,
|
||||
0.0,
|
||||
in_1_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
in_2_dev.GetDeviceBuffer(),
|
||||
@@ -278,8 +278,8 @@ int main(int argc, char* argv[])
|
||||
arrOutLengths,
|
||||
arrOutStrides,
|
||||
reduceDims_2,
|
||||
alpha,
|
||||
beta,
|
||||
static_cast<double>(alpha),
|
||||
static_cast<double>(beta),
|
||||
in_2_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
out_dev.GetDeviceBuffer(),
|
||||
|
||||
@@ -180,8 +180,8 @@ int reduce_multiblock_atomic_add_impl(bool do_verification,
|
||||
arrOutLengths,
|
||||
arrOutStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
static_cast<double>(alpha),
|
||||
static_cast<double>(beta),
|
||||
in.mData.data(),
|
||||
nullptr,
|
||||
out_ref.mData.data(),
|
||||
@@ -208,8 +208,8 @@ int reduce_multiblock_atomic_add_impl(bool do_verification,
|
||||
arrOutLengths,
|
||||
arrOutStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
static_cast<double>(alpha),
|
||||
static_cast<double>(beta),
|
||||
in_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
out_dev.GetDeviceBuffer(),
|
||||
|
||||
@@ -56,8 +56,8 @@ class SimpleAppArgs
|
||||
int option_index = 0;
|
||||
|
||||
public:
|
||||
std::vector<size_t> inLengths = {8, 128, 2048};
|
||||
std::vector<AccDataType> scales = {2.0f, 2.0f};
|
||||
std::vector<size_t> inLengths = {8, 128, 2048};
|
||||
std::vector<double> scales = {2.0, 2.0};
|
||||
|
||||
bool do_verification = true;
|
||||
int init_method = 2;
|
||||
@@ -151,8 +151,8 @@ int main(int argc, char* argv[])
|
||||
auto inStrides = in.mDesc.GetStrides();
|
||||
auto outStrides = out.mDesc.GetStrides();
|
||||
|
||||
AccDataType alpha = args.scales[0];
|
||||
AccDataType beta = args.scales[1];
|
||||
double alpha = args.scales[0];
|
||||
double beta = args.scales[1];
|
||||
|
||||
std::cout << "in: " << in.mDesc << std::endl;
|
||||
std::cout << "out: " << out.mDesc << std::endl;
|
||||
@@ -221,8 +221,8 @@ int main(int argc, char* argv[])
|
||||
auto argument_ptr = device_instance.MakeArgumentPointer(i_inLengths,
|
||||
i_inStrides,
|
||||
reduceDims,
|
||||
&alpha,
|
||||
&beta,
|
||||
alpha,
|
||||
beta,
|
||||
in_dev.GetDeviceBuffer(),
|
||||
out_dev.GetDeviceBuffer(),
|
||||
PassThrough{},
|
||||
|
||||
@@ -217,8 +217,8 @@ int mean_meansquare_dual_reduce_test(size_t n,
|
||||
size_t invariant_total_length = n;
|
||||
size_t reduce_total_length = h * w * c;
|
||||
|
||||
const AccDataType alpha = ck::type_convert<AccDataType>(1.0f);
|
||||
const AccDataType beta = ck::type_convert<AccDataType>(0.0f);
|
||||
const double alpha = 1.0f;
|
||||
const double beta = 0.0f;
|
||||
|
||||
std::size_t num_thread = 1;
|
||||
|
||||
@@ -267,8 +267,8 @@ int mean_meansquare_dual_reduce_test(size_t n,
|
||||
i_outLengths,
|
||||
{i_outStrides, i_outStrides},
|
||||
reduceDims,
|
||||
{&alpha, &alpha},
|
||||
{&beta, &beta},
|
||||
{alpha, alpha},
|
||||
{beta, beta},
|
||||
in_dev.GetDeviceBuffer(),
|
||||
{mean_dev.GetDeviceBuffer(), meansquare_dev.GetDeviceBuffer()},
|
||||
ck::make_tuple(InElementwiseOperation_Mean{}, InElementwiseOperation_Meansquare{}),
|
||||
|
||||
@@ -32,7 +32,7 @@ struct DeviceElementwiseNormalization : public BaseOperator
|
||||
const std::vector<index_t> betaStrides,
|
||||
const std::vector<index_t> yStrides,
|
||||
const std::vector<index_t> reduceDims,
|
||||
AccDataType epsilon,
|
||||
double epsilon,
|
||||
const std::array<const void*, NumInput> in_dev_buffers,
|
||||
const void* p_gamma,
|
||||
const void* p_beta,
|
||||
|
||||
@@ -32,8 +32,8 @@ struct DeviceMultipleReduce : public BaseOperator
|
||||
const std::array<index_t, NumOutputDim> outLengths,
|
||||
const std::array<std::array<index_t, NumOutputDim>, NumReduction> outStrides,
|
||||
const std::array<int, NumReduceDim> reduceDims,
|
||||
const std::array<const void*, NumReduction> alphas,
|
||||
const std::array<const void*, NumReduction> betas,
|
||||
const std::array<double, NumReduction> alphas,
|
||||
const std::array<double, NumReduction> betas,
|
||||
const void* in_dev,
|
||||
const std::array<void*, NumReduction> out_dev_buffers,
|
||||
const InElementwiseOperationTuple in_elementwise_op_tuple,
|
||||
|
||||
@@ -28,7 +28,7 @@ struct DeviceNormalization : public BaseOperator
|
||||
const std::vector<index_t> betaStrides,
|
||||
const std::vector<index_t> yStrides,
|
||||
const std::vector<index_t> reduceDims,
|
||||
AccDataType epsilon,
|
||||
double epsilon,
|
||||
const void* p_x,
|
||||
const void* p_gamma,
|
||||
const void* p_beta,
|
||||
|
||||
@@ -33,8 +33,8 @@ struct DeviceReduce : public BaseOperator
|
||||
const std::array<index_t, NumOutDim> outLengths,
|
||||
const std::array<index_t, NumOutDim> outStrides,
|
||||
const std::array<int, NumReduceDim> reduceDims,
|
||||
float alpha,
|
||||
float beta,
|
||||
double alpha,
|
||||
double beta,
|
||||
const void* in_dev,
|
||||
const void* in_index_dev,
|
||||
void* out_dev,
|
||||
|
||||
@@ -27,10 +27,8 @@ struct DeviceSoftmax : public BaseOperator
|
||||
// @param[in] inLengths Input tensor extent(s) from high to low dimension
|
||||
// @param[in] inStrides Input tensor stride(s) from high to low dimension
|
||||
// @param[in] reduceDims The dimension(s) the normalization operation is applied
|
||||
// @param[in] alpha Typeless pointer in host memory storing the alpha scaling
|
||||
// value as type AccDataType
|
||||
// @param[in] beta Typeless pointer in host memory storing the beta scaling
|
||||
// value as type AccDataType
|
||||
// @param[in] alpha double type value
|
||||
// @param[in] beta double type value
|
||||
// @param[in] in_dev Typeless const pointer in device memory storing the input
|
||||
// tensor
|
||||
// @param out_dev Typeless pointer in device memory storing the output tensor
|
||||
@@ -43,8 +41,8 @@ struct DeviceSoftmax : public BaseOperator
|
||||
MakeArgumentPointer(const std::vector<index_t> inLengths,
|
||||
const std::vector<index_t> inStrides,
|
||||
const std::vector<int> reduceDims,
|
||||
const void* alpha,
|
||||
const void* beta,
|
||||
double alpha,
|
||||
double beta,
|
||||
const void* in_dev,
|
||||
void* out_dev,
|
||||
InElementwiseOp in_elementwise_op,
|
||||
|
||||
@@ -270,18 +270,18 @@ struct DeviceElementwiseNormalizationImpl
|
||||
const std::vector<index_t> reduceDims,
|
||||
XElementwiseOperation x_elementwise_op,
|
||||
YElementwiseOperation y_elementwise_op,
|
||||
AccDataType epsilon,
|
||||
double epsilon,
|
||||
const std::array<const void*, NumInput> in_dev_buffers,
|
||||
const GammaDataType* p_gamma,
|
||||
const BetaDataType* p_beta,
|
||||
YDataType* p_y)
|
||||
: epsilon_(epsilon),
|
||||
p_gamma_(p_gamma),
|
||||
: p_gamma_(p_gamma),
|
||||
p_beta_(p_beta),
|
||||
p_y_(p_y),
|
||||
x_elementwise_op_(x_elementwise_op),
|
||||
y_elementwise_op_(y_elementwise_op)
|
||||
{
|
||||
epsilon_ = static_cast<AccDataType>(epsilon);
|
||||
|
||||
Lengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims);
|
||||
for(int i = 0; i < NumInput; i++)
|
||||
@@ -543,7 +543,7 @@ struct DeviceElementwiseNormalizationImpl
|
||||
const std::vector<index_t> betaStrides,
|
||||
const std::vector<index_t> yStrides,
|
||||
const std::vector<index_t> reduceDims,
|
||||
AccDataType epsilon,
|
||||
double epsilon,
|
||||
const std::array<const void*, NumInput> in_dev_buffers,
|
||||
const void* p_gamma,
|
||||
const void* p_beta,
|
||||
|
||||
@@ -270,8 +270,8 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank,
|
||||
const std::array<index_t, NumOutputDim>& outLengths,
|
||||
const std::array<std::array<index_t, NumOutputDim>, NumReduction>& outStridesArray,
|
||||
const std::array<int, NumReduceDim>& reduceDims,
|
||||
const std::array<const void*, NumReduction>& alphas,
|
||||
const std::array<const void*, NumReduction>& betas,
|
||||
const std::array<double, NumReduction>& alphas,
|
||||
const std::array<double, NumReduction>& betas,
|
||||
const void* in_dev,
|
||||
const std::array<void*, NumReduction>& out_dev_buffers,
|
||||
const InElementwiseOperationTuple in_elementwise_op_tuple,
|
||||
@@ -286,8 +286,8 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank,
|
||||
|
||||
for(size_t i = 0; i < NumReduction; i++)
|
||||
{
|
||||
alpha_values_(i) = *static_cast<const AccDataType*>(alphas[i]);
|
||||
beta_values_(i) = *static_cast<const AccDataType*>(betas[i]);
|
||||
alpha_values_(i) = static_cast<AccDataType>(alphas[i]);
|
||||
beta_values_(i) = static_cast<AccDataType>(betas[i]);
|
||||
};
|
||||
|
||||
in_dev_ = static_cast<const InDataType*>(in_dev);
|
||||
@@ -547,8 +547,8 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank,
|
||||
const std::array<index_t, NumOutputDim> outLengths,
|
||||
const std::array<std::array<index_t, NumOutputDim>, NumReduction> outStridesArray,
|
||||
const std::array<int, NumReduceDim> reduceDims,
|
||||
const std::array<const void*, NumReduction> alphas,
|
||||
const std::array<const void*, NumReduction> betas,
|
||||
const std::array<double, NumReduction> alphas,
|
||||
const std::array<double, NumReduction> betas,
|
||||
const void* in_dev,
|
||||
const std::array<void*, NumReduction> out_dev_buffers,
|
||||
const InElementwiseOperationTuple in_elementwise_op_tuple,
|
||||
|
||||
@@ -195,8 +195,8 @@ struct DeviceMultipleReduceThreadWise : public DeviceMultipleReduce<Rank,
|
||||
const std::array<index_t, NumOutputDim>& outLengths,
|
||||
const std::array<std::array<index_t, NumOutputDim>, NumReduction>& outStridesArray,
|
||||
const std::array<int, NumReduceDim>& reduceDims,
|
||||
const std::array<const void*, NumReduction>& alphas,
|
||||
const std::array<const void*, NumReduction>& betas,
|
||||
const std::array<double, NumReduction>& alphas,
|
||||
const std::array<double, NumReduction>& betas,
|
||||
const void* in_dev,
|
||||
const std::array<void*, NumReduction>& out_dev_buffers,
|
||||
const InElementwiseOperationTuple in_elementwise_op_tuple,
|
||||
@@ -211,8 +211,8 @@ struct DeviceMultipleReduceThreadWise : public DeviceMultipleReduce<Rank,
|
||||
|
||||
for(size_t i = 0; i < NumReduction; i++)
|
||||
{
|
||||
alpha_values_(i) = *static_cast<const AccDataType*>(alphas[i]);
|
||||
beta_values_(i) = *static_cast<const AccDataType*>(betas[i]);
|
||||
alpha_values_(i) = static_cast<AccDataType>(alphas[i]);
|
||||
beta_values_(i) = static_cast<AccDataType>(betas[i]);
|
||||
};
|
||||
|
||||
in_dev_ = static_cast<const InDataType*>(in_dev);
|
||||
@@ -374,8 +374,8 @@ struct DeviceMultipleReduceThreadWise : public DeviceMultipleReduce<Rank,
|
||||
const std::array<index_t, NumOutputDim> outLengths,
|
||||
const std::array<std::array<index_t, NumOutputDim>, NumReduction> outStridesArray,
|
||||
const std::array<int, NumReduceDim> reduceDims,
|
||||
const std::array<const void*, NumReduction> alphas,
|
||||
const std::array<const void*, NumReduction> betas,
|
||||
const std::array<double, NumReduction> alphas,
|
||||
const std::array<double, NumReduction> betas,
|
||||
const void* in_dev,
|
||||
const std::array<void*, NumReduction> out_dev_buffers,
|
||||
const InElementwiseOperationTuple in_elementwise_op_tuple,
|
||||
|
||||
@@ -221,18 +221,19 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
|
||||
const std::vector<index_t> yStrides,
|
||||
const std::vector<index_t> reduceDims,
|
||||
AccElementwiseOperation acc_elementwise_op,
|
||||
AccDataType epsilon,
|
||||
double epsilon,
|
||||
const XDataType* p_x,
|
||||
const GammaDataType* p_gamma,
|
||||
const BetaDataType* p_beta,
|
||||
YDataType* p_y)
|
||||
: epsilon_(epsilon),
|
||||
p_x_(p_x),
|
||||
: p_x_(p_x),
|
||||
p_gamma_(p_gamma),
|
||||
p_beta_(p_beta),
|
||||
p_y_(p_y),
|
||||
acc_elementwise_op_(acc_elementwise_op)
|
||||
{
|
||||
epsilon_ = static_cast<AccDataType>(epsilon);
|
||||
|
||||
Lengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims);
|
||||
xStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(xStrides, reduceDims);
|
||||
yStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims);
|
||||
@@ -421,7 +422,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
|
||||
const std::vector<index_t> betaStrides,
|
||||
const std::vector<index_t> yStrides,
|
||||
const std::vector<index_t> reduceDims,
|
||||
AccDataType epsilon,
|
||||
double epsilon,
|
||||
const void* p_x,
|
||||
const void* p_gamma,
|
||||
const void* p_beta,
|
||||
|
||||
@@ -217,8 +217,8 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InDataType,
|
||||
const std::array<index_t, NumDstDim> outLengths,
|
||||
const std::array<index_t, NumDstDim> outStrides,
|
||||
const std::array<int, NumReduceDim> reduceDims,
|
||||
float alpha,
|
||||
float beta,
|
||||
double alpha,
|
||||
double beta,
|
||||
const InDataType* in_dev,
|
||||
const IndexDataType* in_index_dev,
|
||||
OutDataType* out_dev,
|
||||
@@ -502,8 +502,8 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InDataType,
|
||||
const std::array<index_t, NumDstDim> outLengths,
|
||||
const std::array<index_t, NumDstDim> outStrides,
|
||||
const std::array<int, NumReduceDim> reduceDims,
|
||||
float alpha,
|
||||
float beta,
|
||||
double alpha,
|
||||
double beta,
|
||||
const void* in_dev,
|
||||
const void* in_index_dev,
|
||||
void* out_dev,
|
||||
|
||||
@@ -165,8 +165,8 @@ struct DeviceReduceThreadWise : public DeviceReduce<InDataType,
|
||||
const std::array<index_t, NumDstDim> outLengths,
|
||||
const std::array<index_t, NumDstDim> outStrides,
|
||||
const std::array<int, NumReduceDim> reduceDims,
|
||||
float alpha,
|
||||
float beta,
|
||||
double alpha,
|
||||
double beta,
|
||||
const InDataType* in_dev,
|
||||
OutDataType* out_dev,
|
||||
IndexDataType* out_index_dev,
|
||||
@@ -341,8 +341,8 @@ struct DeviceReduceThreadWise : public DeviceReduce<InDataType,
|
||||
const std::array<index_t, NumDstDim> outLengths,
|
||||
const std::array<index_t, NumDstDim> outStrides,
|
||||
const std::array<int, NumReduceDim> reduceDims,
|
||||
float alpha,
|
||||
float beta,
|
||||
double alpha,
|
||||
double beta,
|
||||
const void* in_dev,
|
||||
const void* in_index_dev,
|
||||
void* out_dev,
|
||||
|
||||
@@ -156,19 +156,20 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
|
||||
Argument(const std::vector<index_t> inLengths,
|
||||
const std::vector<index_t> inStrides,
|
||||
const std::vector<index_t> reduceDims,
|
||||
AccDataType alpha,
|
||||
AccDataType beta,
|
||||
double alpha,
|
||||
double beta,
|
||||
const InDataType* in_dev,
|
||||
OutDataType* out_dev,
|
||||
InElementwiseOp in_elementwise_op,
|
||||
AccElementwiseOp acc_elementwise_op)
|
||||
: alpha_{alpha},
|
||||
beta_{beta},
|
||||
in_dev_{in_dev},
|
||||
: in_dev_{in_dev},
|
||||
out_dev_{out_dev},
|
||||
in_elementwise_op_{in_elementwise_op},
|
||||
acc_elementwise_op_{acc_elementwise_op}
|
||||
{
|
||||
alpha_ = static_cast<AccDataType>(alpha);
|
||||
beta_ = static_cast<AccDataType>(beta);
|
||||
|
||||
if(Rank != inLengths.size() || Rank != inStrides.size() ||
|
||||
NumReduceDim != reduceDims.size())
|
||||
{
|
||||
@@ -336,8 +337,8 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
|
||||
static auto MakeArgument(const std::vector<index_t> inLengths,
|
||||
const std::vector<index_t> inStrides,
|
||||
const std::vector<int> reduceDims,
|
||||
const AccDataType alpha,
|
||||
const AccDataType beta,
|
||||
double alpha,
|
||||
double beta,
|
||||
const InDataType* in_dev,
|
||||
OutDataType* out_dev,
|
||||
InElementwiseOp in_elementwise_op,
|
||||
@@ -375,8 +376,8 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths,
|
||||
const std::vector<index_t> inStrides,
|
||||
const std::vector<int> reduceDims,
|
||||
const void* alpha,
|
||||
const void* beta,
|
||||
double alpha,
|
||||
double beta,
|
||||
const void* in_dev,
|
||||
void* out_dev,
|
||||
InElementwiseOp in_elementwise_op,
|
||||
@@ -385,8 +386,8 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
|
||||
return std::make_unique<Argument>(inLengths,
|
||||
inStrides,
|
||||
reduceDims,
|
||||
*static_cast<const AccDataType*>(alpha),
|
||||
*static_cast<const AccDataType*>(beta),
|
||||
alpha,
|
||||
beta,
|
||||
static_cast<const InDataType*>(in_dev),
|
||||
static_cast<OutDataType*>(out_dev),
|
||||
in_elementwise_op,
|
||||
|
||||
@@ -56,8 +56,8 @@ struct ReferenceReduce : public device::DeviceReduce<InDataType,
|
||||
const std::array<index_t, NumDstDim> outLengths,
|
||||
const std::array<index_t, NumDstDim> outStrides,
|
||||
const std::array<int, NumReduceDim> reduceDims,
|
||||
float alpha,
|
||||
float beta,
|
||||
double alpha,
|
||||
double beta,
|
||||
const InDataType* in_host,
|
||||
OutDataType* out_host,
|
||||
IndexDataType* out_index_host,
|
||||
@@ -388,8 +388,8 @@ struct ReferenceReduce : public device::DeviceReduce<InDataType,
|
||||
const std::array<index_t, NumDstDim> outLengths,
|
||||
const std::array<index_t, NumDstDim> outStrides,
|
||||
const std::array<int, NumReduceDim> reduceDims,
|
||||
float alpha,
|
||||
float beta,
|
||||
double alpha,
|
||||
double beta,
|
||||
const void* in_host,
|
||||
const void* in_index_host,
|
||||
void* out_host,
|
||||
|
||||
@@ -24,11 +24,14 @@ struct ReferenceSoftmax : public device::BaseOperator
|
||||
{
|
||||
Argument(const Tensor<InDataType>& in,
|
||||
Tensor<OutDataType>& out,
|
||||
AccDataType alpha,
|
||||
AccDataType beta,
|
||||
double alpha,
|
||||
double beta,
|
||||
const std::vector<index_t> sm_reduce_dims)
|
||||
: in_(in), out_(out), alpha_(alpha), beta_(beta), sm_reduce_dims_(sm_reduce_dims)
|
||||
: in_(in), out_(out), sm_reduce_dims_(sm_reduce_dims)
|
||||
{
|
||||
alpha_ = static_cast<AccDataType>(alpha);
|
||||
beta_ = static_cast<AccDataType>(beta);
|
||||
|
||||
// std::cout << "debug: scalar dims: ";
|
||||
for(size_t i = 0; i < in.mDesc.GetNumOfDimension(); i++)
|
||||
{
|
||||
@@ -143,8 +146,8 @@ struct ReferenceSoftmax : public device::BaseOperator
|
||||
|
||||
static auto MakeArgument(const Tensor<InDataType>& in,
|
||||
Tensor<OutDataType>& out,
|
||||
AccDataType alpha,
|
||||
AccDataType beta,
|
||||
double alpha,
|
||||
double beta,
|
||||
const std::vector<index_t> sm_reduce_dims)
|
||||
{
|
||||
return Argument{in, out, alpha, beta, sm_reduce_dims};
|
||||
|
||||
@@ -332,8 +332,8 @@ bool profile_reduce_impl_impl(bool do_verification,
|
||||
arrOutLengths,
|
||||
arrOutStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
static_cast<double>(alpha),
|
||||
static_cast<double>(beta),
|
||||
in.mData.data(),
|
||||
nullptr,
|
||||
out_ref.mData.data(),
|
||||
@@ -361,8 +361,8 @@ bool profile_reduce_impl_impl(bool do_verification,
|
||||
arrOutLengths,
|
||||
arrOutStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
static_cast<double>(alpha),
|
||||
static_cast<double>(beta),
|
||||
in_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
out_dev.GetDeviceBuffer(),
|
||||
|
||||
@@ -48,8 +48,8 @@ bool profile_softmax_impl(int do_verification,
|
||||
std::vector<index_t> in_length,
|
||||
std::vector<index_t> in_strides,
|
||||
std::vector<index_t> reduce_dims,
|
||||
AccDataType alpha,
|
||||
AccDataType beta)
|
||||
double alpha,
|
||||
double beta)
|
||||
{
|
||||
if(Rank != in_length.size())
|
||||
{
|
||||
@@ -122,8 +122,8 @@ bool profile_softmax_impl(int do_verification,
|
||||
auto argument_ptr = inst_ptr->MakeArgumentPointer(in_tensor_lengths,
|
||||
in_tensor_strides,
|
||||
reduce_dims,
|
||||
&alpha,
|
||||
&beta,
|
||||
alpha,
|
||||
beta,
|
||||
in_dev.GetDeviceBuffer(),
|
||||
out_dev.GetDeviceBuffer(),
|
||||
PassThrough{},
|
||||
|
||||
@@ -99,8 +99,8 @@ int profile_softmax(int argc, char* argv[])
|
||||
length,
|
||||
stride,
|
||||
reduce,
|
||||
float(alpha),
|
||||
float(beta));
|
||||
double(alpha),
|
||||
double(beta));
|
||||
}
|
||||
else if(data_type == SoftmaxDataType::F32_F32)
|
||||
{
|
||||
@@ -111,8 +111,8 @@ int profile_softmax(int argc, char* argv[])
|
||||
length,
|
||||
stride,
|
||||
reduce,
|
||||
float(alpha),
|
||||
float(beta));
|
||||
double(alpha),
|
||||
double(beta));
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -131,8 +131,8 @@ int profile_softmax(int argc, char* argv[])
|
||||
length,
|
||||
stride,
|
||||
reduce,
|
||||
float(alpha),
|
||||
float(beta));
|
||||
double(alpha),
|
||||
double(beta));
|
||||
}
|
||||
else if(data_type == SoftmaxDataType::F32_F32)
|
||||
{
|
||||
@@ -143,8 +143,8 @@ int profile_softmax(int argc, char* argv[])
|
||||
length,
|
||||
stride,
|
||||
reduce,
|
||||
float(alpha),
|
||||
float(beta));
|
||||
double(alpha),
|
||||
double(beta));
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user