mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 03:49:41 +00:00
Merge commit 'b0a2d99d100f2e4212ebbed080acb49a404035ab' into develop
This commit is contained in:
@@ -22,6 +22,8 @@ using F32 = float;
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
|
||||
using NchwLayout = ck::tensor_layout::convolution::NCHW;
|
||||
using NhwcLayout = ck::tensor_layout::convolution::NHWC;
|
||||
using UnaryScale = ck::tensor_operation::element_wise::Scale;
|
||||
using UnarySquare = ck::tensor_operation::element_wise::UnarySquare;
|
||||
using UnaryScaleSquare =
|
||||
@@ -73,11 +75,11 @@ int main(int argc, char* argv[])
|
||||
1};
|
||||
ck::ranges::copy(nchw, ab_lengths.begin());
|
||||
|
||||
std::array<Tensor<ADataType>, 2> as = {Tensor<ADataType>(ab_lengths, ab_strides),
|
||||
Tensor<ADataType>(ab_lengths, ab_strides)};
|
||||
std::array<Tensor<ADataType>, 2> as = {Tensor<ADataType>(ab_lengths, ab_strides, NchwLayout{}),
|
||||
Tensor<ADataType>(ab_lengths, ab_strides, NchwLayout{})};
|
||||
Tensor<ADataType>& a0 = as[0];
|
||||
Tensor<ADataType>& a1 = as[1];
|
||||
Tensor<BDataType> b(ab_lengths, ab_strides);
|
||||
Tensor<BDataType> b(ab_lengths, ab_strides, NchwLayout{});
|
||||
float alpha = 3.f;
|
||||
float beta = 2.f;
|
||||
a0.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
@@ -134,7 +136,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
Tensor<BDataType> host_b(ab_lengths, ab_strides);
|
||||
Tensor<BDataType> host_b(ab_lengths, ab_strides, NchwLayout{});
|
||||
|
||||
using ReferenceElementwiseInstance = ck::tensor_operation::host::
|
||||
ReferenceElementwise<2, ADataType, BDataType, BinaryAddUnaryScaleSquare>;
|
||||
|
||||
Reference in New Issue
Block a user