Replace __host__ __device__ qualifiers with CK_TILE_HOST_DEVICE in philox

This commit is contained in:
danyao12
2024-05-29 16:20:34 +08:00
parent 1c511b3e7d
commit 58f61716b5

View File

@@ -3,13 +3,15 @@
#pragma once
#include "ck_tile/core/config.hpp"
namespace ck_tile {
// Reference: https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/philox.cuh
class philox
{
public:
__host__ __device__ inline philox(unsigned long long seed_, unsigned long long offset_)
CK_TILE_HOST_DEVICE philox(unsigned long long seed_, unsigned long long offset_)
: seed(reinterpret_cast<const uint2&>(seed_))
{
@@ -17,7 +19,7 @@ class philox
tmp->x = offset_;
}
__host__ __device__ inline uint4 get_philox_4x32(const unsigned long long subsequence) const
CK_TILE_HOST_DEVICE uint4 get_philox_4x32(const unsigned long long subsequence) const
{
uint4 counter_ = counter;
@@ -37,7 +39,7 @@ class philox
return output;
}
__host__ __device__ void get_random_16x8(uint8_t* out,
CK_TILE_HOST_DEVICE void get_random_16x8(uint8_t* out,
const unsigned long long subsequence) const
{
uint4 tmp_ph;
@@ -60,7 +62,7 @@ class philox
uint4 counter;
const uint2 seed;
__host__ __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) const
CK_TILE_HOST_DEVICE uint2 mulhilo32(const unsigned int a, const unsigned int b) const
{
uint2* res;
unsigned long long tmp;
@@ -69,7 +71,7 @@ class philox
return *res;
}
__host__ __device__ inline uint4 philox_single_round(const uint4 ctr, const uint2 key) const
CK_TILE_HOST_DEVICE uint4 philox_single_round(const uint4 ctr, const uint2 key) const
{
uint2 res0 = mulhilo32(kPhiloxSA, ctr.x);