mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_Builder] [testing] Integrate device random generators (#3427)
Implemented device random number generators for ck tensors. Includes tests and integration to ck builder testing interface.
This commit is contained in:
@@ -4,6 +4,8 @@
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <math.h>
|
||||
#include "ck/library/utility/device_tensor_generator.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -34,6 +36,12 @@ struct DeviceMem
|
||||
void SetZero() const;
|
||||
template <typename T>
|
||||
void SetValue(T x) const;
|
||||
template <typename T>
|
||||
void FillUniformRandInteger(int min_value, int max_value);
|
||||
template <typename T>
|
||||
void FillUniformRandFp(float min_value, float max_value);
|
||||
template <typename T>
|
||||
void FillNormalRandFp(float sigma, float mean);
|
||||
~DeviceMem();
|
||||
|
||||
void* mpDeviceBuf;
|
||||
@@ -51,4 +59,48 @@ void DeviceMem::SetValue(T x) const
|
||||
set_buffer_value<T><<<1, 1024>>>(static_cast<T*>(mpDeviceBuf), x, mMemSize / sizeof(T));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DeviceMem::FillUniformRandInteger(int min_value, int max_value)
|
||||
{
|
||||
if(mMemSize % sizeof(T) != 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! not entire DeviceMem will be filled");
|
||||
}
|
||||
if(max_value - min_value <= 1)
|
||||
{
|
||||
throw std::runtime_error("Error while filling device tensor with random integer data: max "
|
||||
"value must be at least 2 greater than min value, otherwise "
|
||||
"tensor will be filled by a constant value (end is exclusive)");
|
||||
}
|
||||
if(max_value - 1 == min_value || max_value - 1 == max_value)
|
||||
{
|
||||
throw std::runtime_error("Error while filling device tensor with random integer data: "
|
||||
"insufficient precision in specified range");
|
||||
}
|
||||
|
||||
size_t packed_size = packed_size_v<T>;
|
||||
fill_tensor_uniform_rand_int_values<<<256, 256>>>(
|
||||
static_cast<T*>(mpDeviceBuf), min_value, max_value, (mMemSize * packed_size) / sizeof(T));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DeviceMem::FillUniformRandFp(float min_value, float max_value)
|
||||
{
|
||||
if(mMemSize % sizeof(T) != 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! not entire DeviceMem will be filled");
|
||||
}
|
||||
|
||||
size_t packed_size = packed_size_v<T>;
|
||||
fill_tensor_uniform_rand_fp_values<<<256, 256>>>(
|
||||
static_cast<T*>(mpDeviceBuf), min_value, max_value, (mMemSize * packed_size) / sizeof(T));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DeviceMem::FillNormalRandFp(float sigma, float mean)
|
||||
{
|
||||
|
||||
fill_tensor_norm_rand_fp_values<<<256, 256>>>(
|
||||
static_cast<T*>(mpDeviceBuf), sigma, mean, mMemSize / sizeof(T));
|
||||
}
|
||||
} // namespace ck
|
||||
|
||||
135
include/ck/library/utility/device_tensor_generator.hpp
Normal file
135
include/ck/library/utility/device_tensor_generator.hpp
Normal file
@@ -0,0 +1,135 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/library/utility/device_tensor_generator.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include <cmath>
|
||||
|
||||
// use xorshift for now since it is simple. Should be suitable enough, but feel free to switch in
|
||||
// the future
|
||||
struct ran_state_u32
|
||||
{
|
||||
uint32_t s[4];
|
||||
};
|
||||
|
||||
__device__ uint32_t ran_gen_round_u32(ran_state_u32& state)
|
||||
{
|
||||
uint32_t tmp = state.s[3];
|
||||
state.s[3] = state.s[2];
|
||||
state.s[2] = state.s[1];
|
||||
state.s[1] = state.s[0];
|
||||
tmp ^= tmp << 11;
|
||||
tmp ^= tmp >> 8;
|
||||
state.s[0] = tmp ^ state.s[0] ^ (state.s[0] >> 19);
|
||||
return state.s[0];
|
||||
}
|
||||
|
||||
__device__ ran_state_u32 ran_init(uint32_t seed = 0)
|
||||
{
|
||||
ran_state_u32 state;
|
||||
// use primes for initialization
|
||||
state.s[0] = (blockDim.x * blockIdx.x + threadIdx.x) * 8912741 + 2313212 + seed;
|
||||
state.s[1] =
|
||||
(gridDim.x * blockDim.x - (blockDim.x * blockIdx.x + threadIdx.x)) * 5013829 + 6012697;
|
||||
state.s[2] = (blockDim.x * blockIdx.x + threadIdx.x) * 3412309 + 2912479;
|
||||
state.s[3] =
|
||||
(gridDim.x * blockDim.x - (blockDim.x * blockIdx.x + threadIdx.x)) * 1001447 + 9912307;
|
||||
|
||||
// run 20 rounds
|
||||
for(int i = 0; i < 20; i++)
|
||||
{
|
||||
ran_gen_round_u32(state);
|
||||
}
|
||||
return state;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void fill_tensor_uniform_rand_int_values(T* p,
|
||||
int min_value,
|
||||
int max_value,
|
||||
uint64_t buffer_element_size)
|
||||
{
|
||||
// initial values
|
||||
ran_state_u32 s = ran_init();
|
||||
for(uint64_t i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
i < buffer_element_size / ck::packed_size_v<T>;
|
||||
i += blockDim.x * gridDim.x)
|
||||
{
|
||||
if constexpr(ck::is_same_v<T, ck::pk_i4_t>)
|
||||
{
|
||||
uint8_t hi = ((ran_gen_round_u32(s)) % (max_value - min_value)) + min_value + 8;
|
||||
uint8_t lo = ((ran_gen_round_u32(s)) % (max_value - min_value)) + min_value + 8;
|
||||
ck::pk_i4_t res = ((hi & 0xf) << 4) + (lo & 0xf);
|
||||
p[i] = res;
|
||||
}
|
||||
else
|
||||
{
|
||||
p[i] = ck::type_convert<T, int>(
|
||||
static_cast<int>((ran_gen_round_u32(s)) % (max_value - min_value)) + min_value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void fill_tensor_uniform_rand_fp_values(T* p,
|
||||
float min_value,
|
||||
float max_value,
|
||||
uint64_t buffer_element_size)
|
||||
{
|
||||
// initial values
|
||||
ran_state_u32 s = ran_init();
|
||||
for(uint64_t i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
i < buffer_element_size / ck::packed_size_v<T>;
|
||||
i += blockDim.x * gridDim.x)
|
||||
{
|
||||
if constexpr(ck::is_same_v<T, ck::f4x2_pk_t>)
|
||||
{
|
||||
float u1 =
|
||||
ran_gen_round_u32(s) * (1.0f / 4294967296.0f) * (max_value - min_value) + min_value;
|
||||
float u2 =
|
||||
ran_gen_round_u32(s) * (1.0f / 4294967296.0f) * (max_value - min_value) + min_value;
|
||||
|
||||
p[i] = ck::type_convert<ck::f4x2_t>(ck::float2_t{u1, u2});
|
||||
}
|
||||
else
|
||||
{
|
||||
float ran = ran_gen_round_u32(s) * (1.0f / 4294967296.0f);
|
||||
p[i] = ck::type_convert<T, float>(ran * (max_value - min_value) + min_value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void
|
||||
fill_tensor_norm_rand_fp_values(T* p, float sigma, float mean, uint64_t buffer_element_size)
|
||||
{
|
||||
// initial values
|
||||
ran_state_u32 s = ran_init();
|
||||
float norm[2];
|
||||
for(uint64_t i = blockIdx.x * blockDim.x + threadIdx.x, j = 0; i < buffer_element_size;
|
||||
i += blockDim.x * gridDim.x, j++)
|
||||
{
|
||||
if(j % (2 / ck::packed_size_v<T>) == 0)
|
||||
{
|
||||
float u1 = ran_gen_round_u32(s) * (1.0f / 4294967296.0f);
|
||||
float u2 = ran_gen_round_u32(s) * (1.0f / 4294967296.0f);
|
||||
norm[0] =
|
||||
sigma * std::sqrt(-2.0f * ck::math::log(u1)) * std::cos(2.0f * M_PI * u2) + mean;
|
||||
norm[1] =
|
||||
sigma * std::sqrt(-2.0f * ck::math::log(u1)) * std::sin(2.0f * M_PI * u2) + mean;
|
||||
}
|
||||
|
||||
if constexpr(ck::is_same_v<T, ck::f4x2_pk_t>)
|
||||
{
|
||||
p[i] = ck::type_convert<ck::f4x2_t>(ck::float2_t{norm[0], norm[1]});
|
||||
}
|
||||
else
|
||||
{
|
||||
p[i] = ck::type_convert<T, float>(norm[j % 2]);
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user