From d6361ee5ed7d9972e956097906da521adc05825a Mon Sep 17 00:00:00 2001 From: kabrahamAMD Date: Tue, 30 Dec 2025 19:03:05 +0100 Subject: [PATCH] [CK_Builder] [testing] Integrate device random generators (#3427) Implemented device random number generators for ck tensors. Includes tests and integration to ck builder testing interface. [ROCm/composable_kernel commit: f86bbb1aefdd047b2b0e886dda831417e790f622] --- .../ck_tile/builder/testing/conv_fwd.hpp | 15 ++ .../builder/testing/tensor_initialization.hpp | 82 +++++++ .../ck_tile/builder/testing/testing.hpp | 14 ++ .../conv/ck/test_ckb_conv_fwd_2d_fp16.cpp | 2 + include/ck/library/utility/device_memory.hpp | 52 ++++ .../utility/device_tensor_generator.hpp | 135 +++++++++++ test/CMakeLists.txt | 1 + test/device_memory/CMakeLists.txt | 7 + test/device_memory/test_device_prng.cpp | 227 ++++++++++++++++++ 9 files changed, 535 insertions(+) create mode 100644 experimental/builder/include/ck_tile/builder/testing/tensor_initialization.hpp create mode 100644 include/ck/library/utility/device_tensor_generator.hpp create mode 100644 test/device_memory/CMakeLists.txt create mode 100644 test/device_memory/test_device_prng.cpp diff --git a/experimental/builder/include/ck_tile/builder/testing/conv_fwd.hpp b/experimental/builder/include/ck_tile/builder/testing/conv_fwd.hpp index f329a8a4d3..62d265894a 100644 --- a/experimental/builder/include/ck_tile/builder/testing/conv_fwd.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/conv_fwd.hpp @@ -9,6 +9,7 @@ #include "ck_tile/builder/testing/testing.hpp" #include "ck_tile/builder/testing/extent.hpp" #include "ck_tile/builder/testing/tensor_buffer.hpp" +#include "ck_tile/builder/testing/tensor_initialization.hpp" #include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" /// This file implements common functionality for invoking/testing grouped @@ -238,6 +239,20 @@ UniqueInputs alloc_inputs(const Args& 
args) }; } +/// @brief `init_inputs()` specialization for forward convolution. +/// +/// @tparam SIGNATURE Forward convolution signature. +/// +/// @see alloc_inputs() +template + requires ValidConvSignature && ConvDirectionIsForward && + ValidUniqueInputs +void init_inputs(const Args& args, UniqueInputs& inputs) +{ + init_tensor_buffer_uniform_fp(inputs.input_buf, args.make_input_descriptor(), -2.0f, 2.0f); + init_tensor_buffer_uniform_fp(inputs.weight_buf, args.make_weight_descriptor(), -2.0f, 2.0f); +} + /// @brief `alloc_outputs()` specialization for forward convolution. /// /// @tparam SIGNATURE Forward convolution signature. diff --git a/experimental/builder/include/ck_tile/builder/testing/tensor_initialization.hpp b/experimental/builder/include/ck_tile/builder/testing/tensor_initialization.hpp new file mode 100644 index 0000000000..15cb43f369 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/testing/tensor_initialization.hpp @@ -0,0 +1,82 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include +#include +#include +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/testing/type_traits.hpp" +#include "ck_tile/host/host_tensor.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/library/utility/device_tensor_generator.hpp" + +namespace ck_tile::builder::test { + +template +void init_tensor_buffer_uniform_int(const DeviceBuffer& buf, + const TensorDescriptor
& descriptor, + int min_val, + int max_val) +{ + size_t size = descriptor.get_element_space_size_in_bytes(); + + if(max_val - min_val <= 1) + { + throw std::runtime_error("Error while filling device tensor with random integer data: max " + "value must be at least 2 greater than min value, otherwise " + "tensor will be filled by a constant value (end is exclusive)"); + } + + using ck_type = factory::internal::DataTypeToCK
::type; + + // we might be asked to generate int values on fp data types that don't have the required + // precision + if(static_cast(max_val - 1) == static_cast(min_val)) + { + throw std::runtime_error("Error while filling device tensor with random integer data: " + "insufficient precision in specified range"); + } + size_t packed_size = ck::packed_size_v; + fill_tensor_uniform_rand_int_values<<<256, 256>>>( + static_cast(buf.get()), min_val, max_val, (size * packed_size) / sizeof(ck_type)); +} + +template +void init_tensor_buffer_uniform_fp(const DeviceBuffer& buf, + const TensorDescriptor
& descriptor, + float min_value, + float max_value) +{ + size_t size = descriptor.get_element_space_size_in_bytes(); + + using ck_type = factory::internal::DataTypeToCK
::type; + + size_t packed_size = ck::packed_size_v; + fill_tensor_uniform_rand_fp_values<<<256, 256>>>(reinterpret_cast(buf.get()), + min_value, + max_value, + (size * packed_size) / sizeof(ck_type)); +} + +template +void init_tensor_buffer_normal_fp(const DeviceBuffer& buf, + const TensorDescriptor
& descriptor, + float sigma, + float mean) +{ + size_t size = descriptor.get_element_space_size_in_bytes(); + + using ck_type = factory::internal::DataTypeToCK
::type; + size_t packed_size = ck::packed_size_v; + fill_tensor_norm_rand_fp_values<<<256, 256>>>( + static_cast(buf.get()), sigma, mean, (size * packed_size) / sizeof(ck_type)); +} + +} // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/testing.hpp b/experimental/builder/include/ck_tile/builder/testing/testing.hpp index 1873af2882..a0dfa27409 100644 --- a/experimental/builder/include/ck_tile/builder/testing/testing.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/testing.hpp @@ -205,6 +205,20 @@ template requires ValidUniqueInputs UniqueInputs alloc_inputs(const Args& args); +/// @brief Allocate inputs corresponding to a signature. +/// +/// The `init_inputs()` function is used to initialize pseudo-random data +/// to the tensors specified in the Inputs structure. +/// +/// @tparam SIGNATURE the signature to specialize the structure for. +/// +/// @see Inputs +/// @see UniqueInputs +/// @see tensor_initialization +template + requires ValidUniqueInputs +void init_inputs(const Args& args, UniqueInputs& inputs); + /// @brief Allocate outputs corresponding to a signature. 
/// /// The `alloc_outputs()` function is used to create an instance of diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp index b7eacf5643..aa53aa9666 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp @@ -81,6 +81,8 @@ TEST(Fwd2DFp16_CShufV3_GNHWC, EndToEnd) auto inputs = alloc_inputs(args); auto outputs = alloc_outputs(args); + init_inputs(args, inputs); + auto conv = Instance{}; ckt::run(conv, args, inputs.get(), outputs.get()); } diff --git a/include/ck/library/utility/device_memory.hpp b/include/ck/library/utility/device_memory.hpp index b0ee766ff5..af5cb6ec28 100644 --- a/include/ck/library/utility/device_memory.hpp +++ b/include/ck/library/utility/device_memory.hpp @@ -4,6 +4,8 @@ #pragma once #include +#include +#include "ck/library/utility/device_tensor_generator.hpp" namespace ck { @@ -34,6 +36,12 @@ struct DeviceMem void SetZero() const; template void SetValue(T x) const; + template + void FillUniformRandInteger(int min_value, int max_value); + template + void FillUniformRandFp(float min_value, float max_value); + template + void FillNormalRandFp(float sigma, float mean); ~DeviceMem(); void* mpDeviceBuf; @@ -51,4 +59,48 @@ void DeviceMem::SetValue(T x) const set_buffer_value<<<1, 1024>>>(static_cast(mpDeviceBuf), x, mMemSize / sizeof(T)); } +template +void DeviceMem::FillUniformRandInteger(int min_value, int max_value) +{ + if(mMemSize % sizeof(T) != 0) + { + throw std::runtime_error("wrong! 
not entire DeviceMem will be filled");
+    }
+    if(max_value - min_value <= 1)
+    {
+        throw std::runtime_error("Error while filling device tensor with random integer data: max "
+                                 "value must be at least 2 greater than min value, otherwise "
+                                 "tensor will be filled by a constant value (end is exclusive)");
+    }
+    if(static_cast<T>(max_value - 1) == static_cast<T>(min_value))
+    {
+        throw std::runtime_error("Error while filling device tensor with random integer data: "
+                                 "insufficient precision in specified range");
+    }
+
+    size_t packed_size = packed_size_v<T>;
+    fill_tensor_uniform_rand_int_values<<<256, 256>>>(
+        static_cast<T*>(mpDeviceBuf), min_value, max_value, (mMemSize * packed_size) / sizeof(T));
+}
+
+template <typename T>
+void DeviceMem::FillUniformRandFp(float min_value, float max_value)
+{
+    if(mMemSize % sizeof(T) != 0)
+    {
+        throw std::runtime_error("wrong! not entire DeviceMem will be filled");
+    }
+
+    size_t packed_size = packed_size_v<T>;
+    fill_tensor_uniform_rand_fp_values<<<256, 256>>>(
+        static_cast<T*>(mpDeviceBuf), min_value, max_value, (mMemSize * packed_size) / sizeof(T));
+}
+
+template <typename T>
+void DeviceMem::FillNormalRandFp(float sigma, float mean)
+{
+
+    fill_tensor_norm_rand_fp_values<<<256, 256>>>(
+        static_cast<T*>(mpDeviceBuf), sigma, mean, mMemSize / sizeof(T));
+}
 } // namespace ck
diff --git a/include/ck/library/utility/device_tensor_generator.hpp b/include/ck/library/utility/device_tensor_generator.hpp
new file mode 100644
index 0000000000..4da38bf399
--- /dev/null
+++ b/include/ck/library/utility/device_tensor_generator.hpp
@@ -0,0 +1,135 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+#pragma once
+#include <cstdint>
+
+#include "ck/ck.hpp"
+#include "ck/utility/common_header.hpp"
+#include "ck/utility/data_type.hpp"
+#include <cmath>
+
+// use xorshift for now since it is simple.
Should be suitable enough, but feel free to switch in +// the future +struct ran_state_u32 +{ + uint32_t s[4]; +}; + +__device__ uint32_t ran_gen_round_u32(ran_state_u32& state) +{ + uint32_t tmp = state.s[3]; + state.s[3] = state.s[2]; + state.s[2] = state.s[1]; + state.s[1] = state.s[0]; + tmp ^= tmp << 11; + tmp ^= tmp >> 8; + state.s[0] = tmp ^ state.s[0] ^ (state.s[0] >> 19); + return state.s[0]; +} + +__device__ ran_state_u32 ran_init(uint32_t seed = 0) +{ + ran_state_u32 state; + // use primes for initialization + state.s[0] = (blockDim.x * blockIdx.x + threadIdx.x) * 8912741 + 2313212 + seed; + state.s[1] = + (gridDim.x * blockDim.x - (blockDim.x * blockIdx.x + threadIdx.x)) * 5013829 + 6012697; + state.s[2] = (blockDim.x * blockIdx.x + threadIdx.x) * 3412309 + 2912479; + state.s[3] = + (gridDim.x * blockDim.x - (blockDim.x * blockIdx.x + threadIdx.x)) * 1001447 + 9912307; + + // run 20 rounds + for(int i = 0; i < 20; i++) + { + ran_gen_round_u32(state); + } + return state; +} + +template +__global__ void fill_tensor_uniform_rand_int_values(T* p, + int min_value, + int max_value, + uint64_t buffer_element_size) +{ + // initial values + ran_state_u32 s = ran_init(); + for(uint64_t i = blockIdx.x * blockDim.x + threadIdx.x; + i < buffer_element_size / ck::packed_size_v; + i += blockDim.x * gridDim.x) + { + if constexpr(ck::is_same_v) + { + uint8_t hi = ((ran_gen_round_u32(s)) % (max_value - min_value)) + min_value + 8; + uint8_t lo = ((ran_gen_round_u32(s)) % (max_value - min_value)) + min_value + 8; + ck::pk_i4_t res = ((hi & 0xf) << 4) + (lo & 0xf); + p[i] = res; + } + else + { + p[i] = ck::type_convert( + static_cast((ran_gen_round_u32(s)) % (max_value - min_value)) + min_value); + } + } +} + +template +__global__ void fill_tensor_uniform_rand_fp_values(T* p, + float min_value, + float max_value, + uint64_t buffer_element_size) +{ + // initial values + ran_state_u32 s = ran_init(); + for(uint64_t i = blockIdx.x * blockDim.x + threadIdx.x; + i < 
buffer_element_size / ck::packed_size_v; + i += blockDim.x * gridDim.x) + { + if constexpr(ck::is_same_v) + { + float u1 = + ran_gen_round_u32(s) * (1.0f / 4294967296.0f) * (max_value - min_value) + min_value; + float u2 = + ran_gen_round_u32(s) * (1.0f / 4294967296.0f) * (max_value - min_value) + min_value; + + p[i] = ck::type_convert(ck::float2_t{u1, u2}); + } + else + { + float ran = ran_gen_round_u32(s) * (1.0f / 4294967296.0f); + p[i] = ck::type_convert(ran * (max_value - min_value) + min_value); + } + } +} + +template +__global__ void +fill_tensor_norm_rand_fp_values(T* p, float sigma, float mean, uint64_t buffer_element_size) +{ + // initial values + ran_state_u32 s = ran_init(); + float norm[2]; + for(uint64_t i = blockIdx.x * blockDim.x + threadIdx.x, j = 0; i < buffer_element_size; + i += blockDim.x * gridDim.x, j++) + { + if(j % (2 / ck::packed_size_v) == 0) + { + float u1 = ran_gen_round_u32(s) * (1.0f / 4294967296.0f); + float u2 = ran_gen_round_u32(s) * (1.0f / 4294967296.0f); + norm[0] = + sigma * std::sqrt(-2.0f * ck::math::log(u1)) * std::cos(2.0f * M_PI * u2) + mean; + norm[1] = + sigma * std::sqrt(-2.0f * ck::math::log(u1)) * std::sin(2.0f * M_PI * u2) + mean; + } + + if constexpr(ck::is_same_v) + { + p[i] = ck::type_convert(ck::float2_t{norm[0], norm[1]}); + } + else + { + p[i] = ck::type_convert(norm[j % 2]); + } + } +} diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 81d1ed4063..81e893edf5 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -300,6 +300,7 @@ add_subdirectory(transpose) add_subdirectory(permute_scale) add_subdirectory(wrapper) add_subdirectory(quantization) +add_subdirectory(device_memory) if(SUPPORTED_GPU_TARGETS MATCHES "gfx11") add_subdirectory(wmma_op) endif() diff --git a/test/device_memory/CMakeLists.txt b/test/device_memory/CMakeLists.txt new file mode 100644 index 0000000000..b2c8ab273c --- /dev/null +++ b/test/device_memory/CMakeLists.txt @@ -0,0 +1,7 @@ +# Copyright (c) Advanced Micro Devices, 
Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +add_custom_target(device_mem_tests) +add_gtest_executable(test_device_prng test_device_prng.cpp) +target_link_libraries(test_device_prng PRIVATE utility) +add_dependencies(test_device_prng device_mem_tests) diff --git a/test/device_memory/test_device_prng.cpp b/test/device_memory/test_device_prng.cpp new file mode 100644 index 0000000000..39fa77237d --- /dev/null +++ b/test/device_memory/test_device_prng.cpp @@ -0,0 +1,227 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/ck.hpp" + +template +void convertTypeFromDevice(std::vector& fromDevice, + std::vector& res, + uint64_t num_elements) +{ + for(uint64_t i = 0; i < num_elements / ck::packed_size_v; i++) + { + // since the CPU dosen't have non-standard data types, we need to convert to float + if constexpr(ck::is_same_v, ck::f4x2_pk_t>) + { + ck::float2_t tmp = ck::type_convert(fromDevice[i]); + res[i * 2] = tmp.x; + res[i * 2 + 1] = tmp.y; + } + else if constexpr(ck::is_same_v, ck::pk_i4_t>) + { + uint8_t packed = fromDevice[i].data; + + int hi = (packed >> 4) & 0x0f; + int lo = packed & 0x0f; + res[i * 2] = static_cast(hi - 8); + res[i * 2 + 1] = static_cast(lo - 8); + } + else + { + res[i] = ck::type_convert(fromDevice[i]); + } + } +} + +template +void TDevRanUniGenInt(int min_val, int max_val, uint64_t num_elements) +{ + + size_t packed_size = ck::packed_size_v; + + ck::DeviceMem test_buf(sizeof(T) * num_elements / packed_size); + std::vector from_host(num_elements / packed_size); + std::vector host_elements(num_elements); + + test_buf.FillUniformRandInteger(min_val, max_val); + test_buf.FromDevice(&from_host[0]); + + uint64_t num_equal = 0; + bool in_range = true; + 
+ convertTypeFromDevice(from_host, host_elements, num_elements); + // very basic checks: check if all data points are in range and + // hf data is within 6 sigma of expected value + for(uint64_t i = 0; i < num_elements; i++) + { + if(host_elements[i] >= max_val || host_elements[i] < min_val) + { + in_range = false; + } + if(i > 0) + { + if(host_elements[i] == host_elements[i - 1]) + { + num_equal++; + } + } + } + EXPECT_TRUE(in_range); + + double expected_mean = + (static_cast(num_elements) - 1.0) / (static_cast(max_val - min_val)); + double std_dev = std::sqrt(expected_mean); + double upper_bound = expected_mean + 6 * std_dev; + double lower_bound = expected_mean - 6 * std_dev; + + // in these cases the test parameters are unsuitable + EXPECT_TRUE(lower_bound > 1.0); + EXPECT_TRUE(upper_bound < static_cast(num_elements) - 2.0); + + // printf("lower bound: %f upper bound: %f actual: %d\n", + // lower_bound, + // upper_bound, + // static_cast(num_equal)); + EXPECT_TRUE(static_cast(num_equal) > lower_bound); + EXPECT_TRUE(static_cast(num_equal) < upper_bound); +} + +template +void TDevRanUniGenFp(double min_val, + double max_val, + uint64_t num_elements, + double std_err_tolerance = 6.0) +{ + size_t packed_size = ck::packed_size_v; + ck::DeviceMem test_buf(sizeof(T) * num_elements / packed_size); + std::vector host_buf(num_elements / packed_size); + std::vector host_elements(num_elements); + + test_buf.FillUniformRandFp(min_val, max_val); + test_buf.FromDevice(&host_buf[0]); + + bool in_range = true; + double accum_mean = 0.0; + double accum_variance = 0.0; + + // #kabraham: with floats, we can actually do some more extensive tests, + // compute mean, std_dev and std_err and compare these to expected values + convertTypeFromDevice(host_buf, host_elements, num_elements); + for(uint64_t i = 0; i < num_elements; i++) + { + if(host_elements[i] > max_val || host_elements[i] < min_val) + { + in_range = false; + } + accum_mean += host_elements[i]; + } + 
EXPECT_TRUE(in_range); + EXPECT_TRUE(accum_mean != 0.0); + double mean = accum_mean / num_elements; + + for(uint64_t i = 0; i < num_elements; i++) + { + accum_variance += std::pow(host_elements[i] - mean, 2); + } + double std_dev = std::sqrt(accum_variance) / num_elements; + + double expected_mean = (min_val + max_val) / 2.0; + double expected_std_dev = (max_val - min_val) / std::sqrt(12 * num_elements); + double std_err = expected_std_dev / sqrt(num_elements); + // printf( + // "Expected: mean: %f std_dev: %f std_err : %f\n", expected_mean, expected_std_dev, + // std_err); + // printf(" Actual: mean: %f std_dev: %f \n", mean, std_dev); + EXPECT_TRUE(abs(mean - expected_mean) < 6 * expected_std_dev); + EXPECT_TRUE(abs(std_dev - expected_std_dev) < std_err_tolerance * std_err); +} + +template +void TDevRanNormGenFp(double sigma, + double mean, + uint64_t num_elements, + double ERRF_BUCKET_SIZE = 0.1, + double ERRF_BUCKET_RANGE = 3.0, + double sig_tolerence = 6.0) +{ + ck::DeviceMem test_buf(sizeof(T) * num_elements); + std::vector host_buf(num_elements); + std::vector host_elements(num_elements); + + test_buf.FillNormalRandFp(sigma, mean); + test_buf.FromDevice(&host_buf[0]); + + convertTypeFromDevice(host_buf, host_elements, num_elements); + + // #kabraham: compute errf buckets and compare with expected vaules + int ERRF_NUM_BUCKETS = 2 * ERRF_BUCKET_RANGE / ERRF_BUCKET_SIZE + 1; + + std::vector errf_buckets(ERRF_NUM_BUCKETS, 0); + for(uint64_t i = 0; i < num_elements; i++) + { + for(int bucket = 0; bucket < ERRF_NUM_BUCKETS; bucket++) + { + // #kabraham: count exact hits as half (kind of relevant for utra-low-precision formats) + if(host_elements[i] < sigma * (-ERRF_BUCKET_RANGE + bucket * ERRF_BUCKET_SIZE) + mean) + { + errf_buckets[bucket] += 2; + } + else if(host_elements[i] <= + sigma * (-ERRF_BUCKET_RANGE + bucket * ERRF_BUCKET_SIZE) + mean) + { + errf_buckets[bucket] += 1; + } + } + } + + for(int bucket = 0; bucket < ERRF_NUM_BUCKETS; bucket++) + { + double 
expected_num_entries = + (std::erfc((ERRF_BUCKET_RANGE - bucket * ERRF_BUCKET_SIZE) / std::sqrt(2))) * 0.5 * + num_elements; + double noise_range = std::sqrt(expected_num_entries); + // printf("Expected for bucket %d: %d. Actual: %d \n", + // bucket, + // static_cast(expected_num_entries), + // static_cast(errf_buckets[bucket] / 2)); + EXPECT_TRUE(errf_buckets[bucket] / 2 >= expected_num_entries - sig_tolerence * noise_range); + EXPECT_TRUE(errf_buckets[bucket] / 2 <= expected_num_entries + sig_tolerence * noise_range); + } +} + +TEST(TDevIntegerRanUniGen, U8) { TDevRanUniGenInt(0, 2, 15000); } +TEST(TDevIntegerRanUniGen, U16) { TDevRanUniGenInt(0, 100, 100000); } +TEST(TDevIntegerRanUniGen, U32) { TDevRanUniGenInt(0, 10000, 10000000); } +TEST(TDevIntegerRanUniGen, I4) { TDevRanUniGenInt(-2, 2, 10000000); } + +TEST(TDevIntegerRanUniGen, F32) { TDevRanUniGenInt(-2, 2, 10000000); } +TEST(TDevIntegerRanUniGen, F16) { TDevRanUniGenInt(-2, 2, 1000000); } + +TEST(TDevFpRanUniGen, F32_1) { TDevRanUniGenFp(0, 1, 100000); } +TEST(TDevFpRanUniGen, F32_2) { TDevRanUniGenFp(0, 37, 73000); } +TEST(TDevFpRanUniGen, F32_3) { TDevRanUniGenFp(-2, 1, 84000); } + +TEST(TDevFpRanUniGen, F16) { TDevRanUniGenFp(-1, 1, 100000); } +TEST(TDevFpRanUniGen, BF16) { TDevRanUniGenFp(0, 2, 100000); } +TEST(TDevFpRanUniGen, F8) { TDevRanUniGenFp(0, 2, 100000); } +TEST(TDevFpRanUniGen, BF8) { TDevRanUniGenFp(-5, 5, 100000); } +TEST(TDevFpRanUniGen, F4) { TDevRanUniGenFp(-5, 5, 100000, 20.0); } + +TEST(TDevRanNormGenFp, F32_1) { TDevRanNormGenFp(1, 0, 1000000); } +TEST(TDevRanNormGenFp, F32_2) { TDevRanNormGenFp(5, -2, 10000000, 0.2, 5.0); } + +TEST(TDevRanNormGenFp, F16) { TDevRanNormGenFp(5, -2, 100000); } +TEST(TDevRanNormGenFp, BF16) { TDevRanNormGenFp(5, -2, 100000); } + +TEST(TDevRanNormGenFp, F8) { TDevRanNormGenFp(2, 0, 100000, 0.5, 2.0, 10.0); } +TEST(TDevRanNormGenFp, BF8) { TDevRanNormGenFp(16, 0, 100000, 0.5, 2.0, 30.0); } + +TEST(TDevRanNormGenFp, F4) { TDevRanNormGenFp(2, 0, 100000, 
0.5, 3.0, 30.0); }