mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_Builder] [testing] Integrate device random generators (#3427)
Implemented device random number generators for ck tensors. Includes tests and integration to ck builder testing interface.
This commit is contained in:
@@ -9,6 +9,7 @@
|
||||
#include "ck_tile/builder/testing/testing.hpp"
|
||||
#include "ck_tile/builder/testing/extent.hpp"
|
||||
#include "ck_tile/builder/testing/tensor_buffer.hpp"
|
||||
#include "ck_tile/builder/testing/tensor_initialization.hpp"
|
||||
#include "ck/library/utility/convolution_parameter.hpp"
|
||||
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
|
||||
/// This file implements common functionality for invoking/testing grouped
|
||||
@@ -238,6 +239,20 @@ UniqueInputs<SIGNATURE> alloc_inputs(const Args<SIGNATURE>& args)
|
||||
};
|
||||
}
|
||||
|
||||
/// @brief `init_inputs()` specialization for forward convolution.
|
||||
///
|
||||
/// @tparam SIGNATURE Forward convolution signature.
|
||||
///
|
||||
/// @see alloc_inputs()
|
||||
template <auto SIGNATURE>
    requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE> &&
             ValidUniqueInputs<SIGNATURE>
void init_inputs(const Args<SIGNATURE>& args, UniqueInputs<SIGNATURE>& inputs)
{
    // Fill input and weight tensors with uniform pseudo-random values in [-2, 2)
    // (the uniform fp generator's upper bound is exclusive). Initialization happens
    // on the device, directly into the already-allocated buffers.
    init_tensor_buffer_uniform_fp(inputs.input_buf, args.make_input_descriptor(), -2.0f, 2.0f);
    init_tensor_buffer_uniform_fp(inputs.weight_buf, args.make_weight_descriptor(), -2.0f, 2.0f);
}
|
||||
|
||||
/// @brief `alloc_outputs()` specialization for forward convolution.
|
||||
///
|
||||
/// @tparam SIGNATURE Forward convolution signature.
|
||||
|
||||
@@ -0,0 +1,82 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <stdexcept>
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <span>
|
||||
#include <concepts>
|
||||
#include <hip/hip_runtime.h>
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/testing/type_traits.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/library/utility/device_tensor_generator.hpp"
|
||||
|
||||
namespace ck_tile::builder::test {
|
||||
|
||||
template <DataType DT>
|
||||
void init_tensor_buffer_uniform_int(const DeviceBuffer& buf,
|
||||
const TensorDescriptor<DT>& descriptor,
|
||||
int min_val,
|
||||
int max_val)
|
||||
{
|
||||
size_t size = descriptor.get_element_space_size_in_bytes();
|
||||
|
||||
if(max_val - min_val <= 1)
|
||||
{
|
||||
throw std::runtime_error("Error while filling device tensor with random integer data: max "
|
||||
"value must be at least 2 greater than min value, otherwise "
|
||||
"tensor will be filled by a constant value (end is exclusive)");
|
||||
}
|
||||
|
||||
using ck_type = factory::internal::DataTypeToCK<DT>::type;
|
||||
|
||||
// we might be asked to generate int values on fp data types that don't have the required
|
||||
// precision
|
||||
if(static_cast<ck_type>(max_val - 1) == static_cast<ck_type>(min_val))
|
||||
{
|
||||
throw std::runtime_error("Error while filling device tensor with random integer data: "
|
||||
"insufficient precision in specified range");
|
||||
}
|
||||
size_t packed_size = ck::packed_size_v<ck_type>;
|
||||
fill_tensor_uniform_rand_int_values<<<256, 256>>>(
|
||||
static_cast<ck_type>(buf.get()), min_val, max_val, (size * packed_size) / sizeof(ck_type));
|
||||
}
|
||||
|
||||
template <DataType DT>
|
||||
void init_tensor_buffer_uniform_fp(const DeviceBuffer& buf,
|
||||
const TensorDescriptor<DT>& descriptor,
|
||||
float min_value,
|
||||
float max_value)
|
||||
{
|
||||
size_t size = descriptor.get_element_space_size_in_bytes();
|
||||
|
||||
using ck_type = factory::internal::DataTypeToCK<DT>::type;
|
||||
|
||||
size_t packed_size = ck::packed_size_v<ck_type>;
|
||||
fill_tensor_uniform_rand_fp_values<<<256, 256>>>(reinterpret_cast<ck_type*>(buf.get()),
|
||||
min_value,
|
||||
max_value,
|
||||
(size * packed_size) / sizeof(ck_type));
|
||||
}
|
||||
|
||||
template <DataType DT>
|
||||
void init_tensor_buffer_normal_fp(const DeviceBuffer& buf,
|
||||
const TensorDescriptor<DT>& descriptor,
|
||||
float sigma,
|
||||
float mean)
|
||||
{
|
||||
size_t size = descriptor.get_element_space_size_in_bytes();
|
||||
|
||||
using ck_type = factory::internal::DataTypeToCK<DT>::type;
|
||||
size_t packed_size = ck::packed_size_v<ck_type>;
|
||||
fill_tensor_norm_rand_fp_values<<<256, 256>>>(
|
||||
static_cast<ck_type*>(buf.get()), sigma, mean, (size * packed_size) / sizeof(ck_type));
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
@@ -205,6 +205,20 @@ template <auto SIGNATURE>
|
||||
requires ValidUniqueInputs<SIGNATURE>
|
||||
UniqueInputs<SIGNATURE> alloc_inputs(const Args<SIGNATURE>& args);
|
||||
|
||||
/// @brief Allocate inputs corresponding to a signature.
|
||||
///
|
||||
/// The `init_inputs()` function is used to initialize pseudo-random data
|
||||
/// to the tensors specified in the Inputs structure.
|
||||
///
|
||||
/// @tparam SIGNATURE the signature to specialize the structure for.
|
||||
///
|
||||
/// @see Inputs
|
||||
/// @see UniqueInputs
|
||||
/// @see tensor_initialization
|
||||
template <auto SIGNATURE>
|
||||
requires ValidUniqueInputs<SIGNATURE>
|
||||
void init_inputs(const Args<SIGNATURE>& args, UniqueInputs<SIGNATURE>& inputs);
|
||||
|
||||
/// @brief Allocate outputs corresponding to a signature.
|
||||
///
|
||||
/// The `alloc_outputs()` function is used to create an instance of
|
||||
|
||||
@@ -81,6 +81,8 @@ TEST(Fwd2DFp16_CShufV3_GNHWC, EndToEnd)
|
||||
auto inputs = alloc_inputs(args);
|
||||
auto outputs = alloc_outputs(args);
|
||||
|
||||
init_inputs(args, inputs);
|
||||
|
||||
auto conv = Instance{};
|
||||
ckt::run(conv, args, inputs.get(), outputs.get());
|
||||
}
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <math.h>
|
||||
#include "ck/library/utility/device_tensor_generator.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -34,6 +36,12 @@ struct DeviceMem
|
||||
void SetZero() const;
|
||||
template <typename T>
|
||||
void SetValue(T x) const;
|
||||
template <typename T>
|
||||
void FillUniformRandInteger(int min_value, int max_value);
|
||||
template <typename T>
|
||||
void FillUniformRandFp(float min_value, float max_value);
|
||||
template <typename T>
|
||||
void FillNormalRandFp(float sigma, float mean);
|
||||
~DeviceMem();
|
||||
|
||||
void* mpDeviceBuf;
|
||||
@@ -51,4 +59,48 @@ void DeviceMem::SetValue(T x) const
|
||||
set_buffer_value<T><<<1, 1024>>>(static_cast<T*>(mpDeviceBuf), x, mMemSize / sizeof(T));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DeviceMem::FillUniformRandInteger(int min_value, int max_value)
|
||||
{
|
||||
if(mMemSize % sizeof(T) != 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! not entire DeviceMem will be filled");
|
||||
}
|
||||
if(max_value - min_value <= 1)
|
||||
{
|
||||
throw std::runtime_error("Error while filling device tensor with random integer data: max "
|
||||
"value must be at least 2 greater than min value, otherwise "
|
||||
"tensor will be filled by a constant value (end is exclusive)");
|
||||
}
|
||||
if(max_value - 1 == min_value || max_value - 1 == max_value)
|
||||
{
|
||||
throw std::runtime_error("Error while filling device tensor with random integer data: "
|
||||
"insufficient precision in specified range");
|
||||
}
|
||||
|
||||
size_t packed_size = packed_size_v<T>;
|
||||
fill_tensor_uniform_rand_int_values<<<256, 256>>>(
|
||||
static_cast<T*>(mpDeviceBuf), min_value, max_value, (mMemSize * packed_size) / sizeof(T));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DeviceMem::FillUniformRandFp(float min_value, float max_value)
|
||||
{
|
||||
if(mMemSize % sizeof(T) != 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! not entire DeviceMem will be filled");
|
||||
}
|
||||
|
||||
size_t packed_size = packed_size_v<T>;
|
||||
fill_tensor_uniform_rand_fp_values<<<256, 256>>>(
|
||||
static_cast<T*>(mpDeviceBuf), min_value, max_value, (mMemSize * packed_size) / sizeof(T));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DeviceMem::FillNormalRandFp(float sigma, float mean)
|
||||
{
|
||||
|
||||
fill_tensor_norm_rand_fp_values<<<256, 256>>>(
|
||||
static_cast<T*>(mpDeviceBuf), sigma, mean, mMemSize / sizeof(T));
|
||||
}
|
||||
} // namespace ck
|
||||
|
||||
135
include/ck/library/utility/device_tensor_generator.hpp
Normal file
135
include/ck/library/utility/device_tensor_generator.hpp
Normal file
@@ -0,0 +1,135 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/library/utility/device_tensor_generator.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include <cmath>
|
||||
|
||||
// use xorshift for now since it is simple. Should be suitable enough, but feel free to switch in
|
||||
// the future
|
||||
/// State of the xorshift-style generator: four 32-bit words. s[0] holds the
/// most recently produced output word.
struct ran_state_u32
{
    uint32_t s[4];
};

/// Advance the xorshift-style state by one round and return the next 32-bit
/// pseudo-random value. Not cryptographically secure; intended only for
/// generating test data.
__device__ uint32_t ran_gen_round_u32(ran_state_u32& state)
{
    // Rotate the state window (s[3] <- s[2] <- s[1] <- s[0]) and mix the oldest
    // word with shifted copies of itself and of s[0].
    uint32_t tmp = state.s[3];
    state.s[3] = state.s[2];
    state.s[2] = state.s[1];
    state.s[1] = state.s[0];
    tmp ^= tmp << 11;
    tmp ^= tmp >> 8;
    state.s[0] = tmp ^ state.s[0] ^ (state.s[0] >> 19);
    return state.s[0];
}
|
||||
|
||||
/// Initialize a per-thread generator state from the global thread index (and an
/// optional caller-supplied seed), so every thread produces an independent
/// stream. The state is then warmed up with 20 rounds to decorrelate the
/// nearly-linear seeds of neighboring threads.
__device__ ran_state_u32 ran_init(uint32_t seed = 0)
{
    ran_state_u32 state;
    // use primes for initialization (multipliers; the additive offsets are just
    // fixed constants and not all of them are prime)
    state.s[0] = (blockDim.x * blockIdx.x + threadIdx.x) * 8912741 + 2313212 + seed;
    state.s[1] =
        (gridDim.x * blockDim.x - (blockDim.x * blockIdx.x + threadIdx.x)) * 5013829 + 6012697;
    state.s[2] = (blockDim.x * blockIdx.x + threadIdx.x) * 3412309 + 2912479;
    state.s[3] =
        (gridDim.x * blockDim.x - (blockDim.x * blockIdx.x + threadIdx.x)) * 1001447 + 9912307;

    // run 20 rounds
    for(int i = 0; i < 20; i++)
    {
        ran_gen_round_u32(state);
    }
    return state;
}
|
||||
|
||||
/// Fill `p` with uniform pseudo-random integers in [min_value, max_value)
/// (upper bound exclusive), converted to storage type T.
///
/// @param buffer_element_size number of *unpacked* elements; each thread writes
///        buffer_element_size / packed_size_v<T> objects via a grid-stride loop.
///
/// NOTE(review): `% (max_value - min_value)` introduces a small modulo bias for
/// ranges that do not divide 2^32 — presumably acceptable for test data.
template <typename T>
__global__ void fill_tensor_uniform_rand_int_values(T* p,
                                                    int min_value,
                                                    int max_value,
                                                    uint64_t buffer_element_size)
{
    // initial values
    ran_state_u32 s = ran_init();
    for(uint64_t i = blockIdx.x * blockDim.x + threadIdx.x;
        i < buffer_element_size / ck::packed_size_v<T>;
        i += blockDim.x * gridDim.x)
    {
        if constexpr(ck::is_same_v<T, ck::pk_i4_t>)
        {
            // Two int4 values per byte. +8 biases the signed value into the
            // unsigned nibble range [0, 15] before packing.
            uint8_t hi = ((ran_gen_round_u32(s)) % (max_value - min_value)) + min_value + 8;
            uint8_t lo = ((ran_gen_round_u32(s)) % (max_value - min_value)) + min_value + 8;
            ck::pk_i4_t res = ((hi & 0xf) << 4) + (lo & 0xf);
            p[i] = res;
        }
        else
        {
            p[i] = ck::type_convert<T, int>(
                static_cast<int>((ran_gen_round_u32(s)) % (max_value - min_value)) + min_value);
        }
    }
}
|
||||
|
||||
/// Fill `p` with uniform pseudo-random floating-point values in
/// [min_value, max_value), converted to storage type T.
///
/// @param buffer_element_size number of *unpacked* elements; threads write
///        buffer_element_size / packed_size_v<T> objects via a grid-stride loop.
template <typename T>
__global__ void fill_tensor_uniform_rand_fp_values(T* p,
                                                   float min_value,
                                                   float max_value,
                                                   uint64_t buffer_element_size)
{
    // initial values
    ran_state_u32 s = ran_init();
    for(uint64_t i = blockIdx.x * blockDim.x + threadIdx.x;
        i < buffer_element_size / ck::packed_size_v<T>;
        i += blockDim.x * gridDim.x)
    {
        if constexpr(ck::is_same_v<T, ck::f4x2_pk_t>)
        {
            // Packed fp4: generate two values and pack them into one object.
            // 1/2^32 scales the raw 32-bit word into [0, 1).
            float u1 =
                ran_gen_round_u32(s) * (1.0f / 4294967296.0f) * (max_value - min_value) + min_value;
            float u2 =
                ran_gen_round_u32(s) * (1.0f / 4294967296.0f) * (max_value - min_value) + min_value;

            p[i] = ck::type_convert<ck::f4x2_t>(ck::float2_t{u1, u2});
        }
        else
        {
            // 1/2^32 scales the raw 32-bit word into [0, 1).
            float ran = ran_gen_round_u32(s) * (1.0f / 4294967296.0f);
            p[i] = ck::type_convert<T, float>(ran * (max_value - min_value) + min_value);
        }
    }
}
|
||||
|
||||
/// Fill `p` with normally distributed pseudo-random values (standard deviation
/// `sigma`, mean `mean`) using the Box-Muller transform, which yields two
/// normal variates per pair of uniform draws.
///
/// @param buffer_element_size number of *packed* T objects (NOT unpacked
///        elements — this differs from the uniform kernels above).
///
/// NOTE(review): u1 can be exactly 0 (raw word 0 scaled by 1/2^32), making
/// log(u1) -inf and producing non-finite outputs — TODO confirm acceptable.
template <typename T>
__global__ void
fill_tensor_norm_rand_fp_values(T* p, float sigma, float mean, uint64_t buffer_element_size)
{
    // initial values
    ran_state_u32 s = ran_init();
    float norm[2];
    // j counts this thread's iterations so the pair of Box-Muller outputs is
    // regenerated every (2 / packed_size) iterations: every other iteration for
    // scalar types, every iteration for packed (two-per-object) types.
    for(uint64_t i = blockIdx.x * blockDim.x + threadIdx.x, j = 0; i < buffer_element_size;
        i += blockDim.x * gridDim.x, j++)
    {
        if(j % (2 / ck::packed_size_v<T>) == 0)
        {
            float u1 = ran_gen_round_u32(s) * (1.0f / 4294967296.0f);
            float u2 = ran_gen_round_u32(s) * (1.0f / 4294967296.0f);
            norm[0] =
                sigma * std::sqrt(-2.0f * ck::math::log(u1)) * std::cos(2.0f * M_PI * u2) + mean;
            norm[1] =
                sigma * std::sqrt(-2.0f * ck::math::log(u1)) * std::sin(2.0f * M_PI * u2) + mean;
        }

        if constexpr(ck::is_same_v<T, ck::f4x2_pk_t>)
        {
            // Packed type: both variates go into one object.
            p[i] = ck::type_convert<ck::f4x2_t>(ck::float2_t{norm[0], norm[1]});
        }
        else
        {
            // Scalar type: consume the two variates on alternating iterations.
            p[i] = ck::type_convert<T, float>(norm[j % 2]);
        }
    }
}
|
||||
@@ -300,6 +300,7 @@ add_subdirectory(transpose)
|
||||
add_subdirectory(permute_scale)
|
||||
add_subdirectory(wrapper)
|
||||
add_subdirectory(quantization)
|
||||
add_subdirectory(device_memory)
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx11")
|
||||
add_subdirectory(wmma_op)
|
||||
endif()
|
||||
|
||||
7
test/device_memory/CMakeLists.txt
Normal file
7
test/device_memory/CMakeLists.txt
Normal file
@@ -0,0 +1,7 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# Aggregate target for all device-memory tests.
add_custom_target(device_mem_tests)
add_gtest_executable(test_device_prng test_device_prng.cpp)
target_link_libraries(test_device_prng PRIVATE utility)
# The aggregate target must depend on the test executable so that building
# `device_mem_tests` builds the tests. (The reverse order would only make the
# executable depend on the empty custom target, which is a no-op.)
add_dependencies(device_mem_tests test_device_prng)
|
||||
227
test/device_memory/test_device_prng.cpp
Normal file
227
test/device_memory/test_device_prng.cpp
Normal file
@@ -0,0 +1,227 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
#include <gtest/gtest.h>
|
||||
#include <random>
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/ck.hpp"
|
||||
|
||||
/// Unpack and convert `num_elements` values read back from the device into a
/// host-side vector of a standard type (int or float), so the test statistics
/// can be computed on the CPU.
///
/// @param fromDevice raw device data, num_elements / packed_size_v<inType> objects
/// @param res        output, must hold num_elements converted values
template <typename inType, typename outType>
void convertTypeFromDevice(std::vector<inType>& fromDevice,
                           std::vector<outType>& res,
                           uint64_t num_elements)
{
    for(uint64_t i = 0; i < num_elements / ck::packed_size_v<inType>; i++)
    {
        // since the CPU doesn't have non-standard data types, we need to convert to float
        if constexpr(ck::is_same_v<ck::remove_cvref_t<inType>, ck::f4x2_pk_t>)
        {
            // Packed fp4: one object expands to two floats.
            ck::float2_t tmp = ck::type_convert<ck::float2_t, ck::f4x2_t>(fromDevice[i]);
            res[i * 2] = tmp.x;
            res[i * 2 + 1] = tmp.y;
        }
        else if constexpr(ck::is_same_v<ck::remove_cvref_t<inType>, ck::pk_i4_t>)
        {
            // Packed int4: two nibbles per byte; -8 undoes the unsigned-nibble
            // bias applied when the data was generated.
            uint8_t packed = fromDevice[i].data;

            int hi = (packed >> 4) & 0x0f;
            int lo = packed & 0x0f;
            res[i * 2] = static_cast<outType>(hi - 8);
            res[i * 2 + 1] = static_cast<outType>(lo - 8);
        }
        else
        {
            res[i] = ck::type_convert<outType, inType>(fromDevice[i]);
        }
    }
}
|
||||
|
||||
/// Test driver for the uniform integer generator on storage type T.
///
/// Fills a device buffer with num_elements values in [min_val, max_val), reads
/// them back, and checks (a) every value is in range, and (b) the number of
/// adjacent equal pairs is statistically plausible: for a uniform distribution
/// the expected count is about (n - 1) / range, with an approximately
/// Poisson-like spread of sqrt of that, so the observed count must fall within
/// six of those deviations.
template <typename T>
void TDevRanUniGenInt(int min_val, int max_val, uint64_t num_elements)
{

    size_t packed_size = ck::packed_size_v<T>;

    // Buffer size in bytes: num_elements unpacked values occupy
    // num_elements / packed_size objects of type T.
    ck::DeviceMem test_buf(sizeof(T) * num_elements / packed_size);
    std::vector<T> from_host(num_elements / packed_size);
    std::vector<int> host_elements(num_elements);

    test_buf.FillUniformRandInteger<T>(min_val, max_val);
    test_buf.FromDevice(&from_host[0]);

    uint64_t num_equal = 0;
    bool in_range = true;

    convertTypeFromDevice(from_host, host_elements, num_elements);
    // very basic checks: check if all data points are in range and
    // hf data is within 6 sigma of expected value
    for(uint64_t i = 0; i < num_elements; i++)
    {
        // Range check: max is exclusive, min inclusive.
        if(host_elements[i] >= max_val || host_elements[i] < min_val)
        {
            in_range = false;
        }
        if(i > 0)
        {
            if(host_elements[i] == host_elements[i - 1])
            {
                num_equal++;
            }
        }
    }
    EXPECT_TRUE(in_range);

    // Expected number of adjacent equal pairs for i.i.d. uniform values.
    double expected_mean =
        (static_cast<double>(num_elements) - 1.0) / (static_cast<double>(max_val - min_val));
    double std_dev = std::sqrt(expected_mean);
    double upper_bound = expected_mean + 6 * std_dev;
    double lower_bound = expected_mean - 6 * std_dev;

    // in these cases the test parameters are unsuitable
    EXPECT_TRUE(lower_bound > 1.0);
    EXPECT_TRUE(upper_bound < static_cast<double>(num_elements) - 2.0);

    // printf("lower bound: %f upper bound: %f actual: %d\n",
    //        lower_bound,
    //        upper_bound,
    //        static_cast<int>(num_equal));
    EXPECT_TRUE(static_cast<double>(num_equal) > lower_bound);
    EXPECT_TRUE(static_cast<double>(num_equal) < upper_bound);
}
|
||||
|
||||
template <typename T>
|
||||
void TDevRanUniGenFp(double min_val,
|
||||
double max_val,
|
||||
uint64_t num_elements,
|
||||
double std_err_tolerance = 6.0)
|
||||
{
|
||||
size_t packed_size = ck::packed_size_v<T>;
|
||||
ck::DeviceMem test_buf(sizeof(T) * num_elements / packed_size);
|
||||
std::vector<T> host_buf(num_elements / packed_size);
|
||||
std::vector<float> host_elements(num_elements);
|
||||
|
||||
test_buf.FillUniformRandFp<T>(min_val, max_val);
|
||||
test_buf.FromDevice(&host_buf[0]);
|
||||
|
||||
bool in_range = true;
|
||||
double accum_mean = 0.0;
|
||||
double accum_variance = 0.0;
|
||||
|
||||
// #kabraham: with floats, we can actually do some more extensive tests,
|
||||
// compute mean, std_dev and std_err and compare these to expected values
|
||||
convertTypeFromDevice(host_buf, host_elements, num_elements);
|
||||
for(uint64_t i = 0; i < num_elements; i++)
|
||||
{
|
||||
if(host_elements[i] > max_val || host_elements[i] < min_val)
|
||||
{
|
||||
in_range = false;
|
||||
}
|
||||
accum_mean += host_elements[i];
|
||||
}
|
||||
EXPECT_TRUE(in_range);
|
||||
EXPECT_TRUE(accum_mean != 0.0);
|
||||
double mean = accum_mean / num_elements;
|
||||
|
||||
for(uint64_t i = 0; i < num_elements; i++)
|
||||
{
|
||||
accum_variance += std::pow(host_elements[i] - mean, 2);
|
||||
}
|
||||
double std_dev = std::sqrt(accum_variance) / num_elements;
|
||||
|
||||
double expected_mean = (min_val + max_val) / 2.0;
|
||||
double expected_std_dev = (max_val - min_val) / std::sqrt(12 * num_elements);
|
||||
double std_err = expected_std_dev / sqrt(num_elements);
|
||||
// printf(
|
||||
// "Expected: mean: %f std_dev: %f std_err : %f\n", expected_mean, expected_std_dev,
|
||||
// std_err);
|
||||
// printf(" Actual: mean: %f std_dev: %f \n", mean, std_dev);
|
||||
EXPECT_TRUE(abs(mean - expected_mean) < 6 * expected_std_dev);
|
||||
EXPECT_TRUE(abs(std_dev - expected_std_dev) < std_err_tolerance * std_err);
|
||||
}
|
||||
|
||||
/// Test driver for the normal-distribution generator on storage type T.
///
/// Fills a device buffer with num_elements values ~ N(mean, sigma^2), then
/// histograms the empirical CDF at thresholds mean + sigma * t for t from
/// -ERRF_BUCKET_RANGE to +ERRF_BUCKET_RANGE in steps of ERRF_BUCKET_SIZE, and
/// compares each bucket count against the Gaussian CDF (via erfc) within
/// sig_tolerence times the sqrt-count noise estimate.
template <typename T>
void TDevRanNormGenFp(double sigma,
                      double mean,
                      uint64_t num_elements,
                      double ERRF_BUCKET_SIZE = 0.1,
                      double ERRF_BUCKET_RANGE = 3.0,
                      double sig_tolerence = 6.0)
{
    ck::DeviceMem test_buf(sizeof(T) * num_elements);
    std::vector<T> host_buf(num_elements);
    std::vector<float> host_elements(num_elements);

    test_buf.FillNormalRandFp<T>(sigma, mean);
    test_buf.FromDevice(&host_buf[0]);

    convertTypeFromDevice(host_buf, host_elements, num_elements);

    // #kabraham: compute errf buckets and compare with expected values
    int ERRF_NUM_BUCKETS = 2 * ERRF_BUCKET_RANGE / ERRF_BUCKET_SIZE + 1;

    // Counts are kept doubled so an exact threshold hit can contribute "half":
    // strictly below the threshold adds 2, exactly equal adds 1.
    std::vector<int64_t> errf_buckets(ERRF_NUM_BUCKETS, 0);
    for(uint64_t i = 0; i < num_elements; i++)
    {
        for(int bucket = 0; bucket < ERRF_NUM_BUCKETS; bucket++)
        {
            // #kabraham: count exact hits as half (kind of relevant for ultra-low-precision formats)
            if(host_elements[i] < sigma * (-ERRF_BUCKET_RANGE + bucket * ERRF_BUCKET_SIZE) + mean)
            {
                errf_buckets[bucket] += 2;
            }
            else if(host_elements[i] <=
                    sigma * (-ERRF_BUCKET_RANGE + bucket * ERRF_BUCKET_SIZE) + mean)
            {
                errf_buckets[bucket] += 1;
            }
        }
    }

    for(int bucket = 0; bucket < ERRF_NUM_BUCKETS; bucket++)
    {
        // Gaussian CDF at the bucket threshold: erfc(-t / sqrt(2)) / 2,
        // scaled to an expected element count.
        double expected_num_entries =
            (std::erfc((ERRF_BUCKET_RANGE - bucket * ERRF_BUCKET_SIZE) / std::sqrt(2))) * 0.5 *
            num_elements;
        double noise_range = std::sqrt(expected_num_entries);
        // printf("Expected for bucket %d: %d. Actual: %d \n",
        //        bucket,
        //        static_cast<int>(expected_num_entries),
        //        static_cast<int>(errf_buckets[bucket] / 2));
        EXPECT_TRUE(errf_buckets[bucket] / 2 >= expected_num_entries - sig_tolerence * noise_range);
        EXPECT_TRUE(errf_buckets[bucket] / 2 <= expected_num_entries + sig_tolerence * noise_range);
    }
}
|
||||
|
||||
// Uniform random integers on unsigned-integer and packed-int4 storage types.
TEST(TDevIntegerRanUniGen, U8) { TDevRanUniGenInt<uint8_t>(0, 2, 15000); }
TEST(TDevIntegerRanUniGen, U16) { TDevRanUniGenInt<uint16_t>(0, 100, 100000); }
TEST(TDevIntegerRanUniGen, U32) { TDevRanUniGenInt<uint32_t>(0, 10000, 10000000); }
TEST(TDevIntegerRanUniGen, I4) { TDevRanUniGenInt<ck::pk_i4_t>(-2, 2, 10000000); }

// Uniform random integers stored in floating-point formats.
TEST(TDevIntegerRanUniGen, F32) { TDevRanUniGenInt<float>(-2, 2, 10000000); }
TEST(TDevIntegerRanUniGen, F16) { TDevRanUniGenInt<ck::half_t>(-2, 2, 1000000); }

// Uniform random floats: float32 over several ranges, then reduced-precision
// formats (wider tolerance for fp4, which quantizes heavily).
TEST(TDevFpRanUniGen, F32_1) { TDevRanUniGenFp<float>(0, 1, 100000); }
TEST(TDevFpRanUniGen, F32_2) { TDevRanUniGenFp<float>(0, 37, 73000); }
TEST(TDevFpRanUniGen, F32_3) { TDevRanUniGenFp<float>(-2, 1, 84000); }

TEST(TDevFpRanUniGen, F16) { TDevRanUniGenFp<ck::half_t>(-1, 1, 100000); }
TEST(TDevFpRanUniGen, BF16) { TDevRanUniGenFp<ck::bhalf_t>(0, 2, 100000); }
TEST(TDevFpRanUniGen, F8) { TDevRanUniGenFp<ck::f8_t>(0, 2, 100000); }
TEST(TDevFpRanUniGen, BF8) { TDevRanUniGenFp<ck::bf8_t>(-5, 5, 100000); }
TEST(TDevFpRanUniGen, F4) { TDevRanUniGenFp<ck::f4x2_pk_t>(-5, 5, 100000, 20.0); }

// Normally distributed random values: float32, then reduced-precision formats
// with coarser bucket sizes / larger tolerances.
TEST(TDevRanNormGenFp, F32_1) { TDevRanNormGenFp<float>(1, 0, 1000000); }
TEST(TDevRanNormGenFp, F32_2) { TDevRanNormGenFp<float>(5, -2, 10000000, 0.2, 5.0); }

TEST(TDevRanNormGenFp, F16) { TDevRanNormGenFp<ck::half_t>(5, -2, 100000); }
TEST(TDevRanNormGenFp, BF16) { TDevRanNormGenFp<ck::bhalf_t>(5, -2, 100000); }

TEST(TDevRanNormGenFp, F8) { TDevRanNormGenFp<ck::f8_t>(2, 0, 100000, 0.5, 2.0, 10.0); }
TEST(TDevRanNormGenFp, BF8) { TDevRanNormGenFp<ck::bf8_t>(16, 0, 100000, 0.5, 2.0, 30.0); }

TEST(TDevRanNormGenFp, F4) { TDevRanNormGenFp<ck::f4x2_pk_t>(2, 0, 100000, 0.5, 3.0, 30.0); }
|
||||
Reference in New Issue
Block a user