mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
CK_TILE_HOST_DEVICE in philox
This commit is contained in:
@@ -3,13 +3,15 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Reference: https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/philox.cuh
|
||||
class philox
|
||||
{
|
||||
public:
|
||||
__host__ __device__ inline philox(unsigned long long seed_, unsigned long long offset_)
|
||||
CK_TILE_HOST_DEVICE philox(unsigned long long seed_, unsigned long long offset_)
|
||||
: seed(reinterpret_cast<const uint2&>(seed_))
|
||||
{
|
||||
|
||||
@@ -17,7 +19,7 @@ class philox
|
||||
tmp->x = offset_;
|
||||
}
|
||||
|
||||
__host__ __device__ inline uint4 get_philox_4x32(const unsigned long long subsequence) const
|
||||
CK_TILE_HOST_DEVICE uint4 get_philox_4x32(const unsigned long long subsequence) const
|
||||
{
|
||||
|
||||
uint4 counter_ = counter;
|
||||
@@ -37,7 +39,7 @@ class philox
|
||||
return output;
|
||||
}
|
||||
|
||||
__host__ __device__ void get_random_16x8(uint8_t* out,
|
||||
CK_TILE_HOST_DEVICE void get_random_16x8(uint8_t* out,
|
||||
const unsigned long long subsequence) const
|
||||
{
|
||||
uint4 tmp_ph;
|
||||
@@ -60,7 +62,7 @@ class philox
|
||||
uint4 counter;
|
||||
const uint2 seed;
|
||||
|
||||
__host__ __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) const
|
||||
CK_TILE_HOST_DEVICE uint2 mulhilo32(const unsigned int a, const unsigned int b) const
|
||||
{
|
||||
uint2* res;
|
||||
unsigned long long tmp;
|
||||
@@ -69,7 +71,7 @@ class philox
|
||||
return *res;
|
||||
}
|
||||
|
||||
__host__ __device__ inline uint4 philox_single_round(const uint4 ctr, const uint2 key) const
|
||||
CK_TILE_HOST_DEVICE uint4 philox_single_round(const uint4 ctr, const uint2 key) const
|
||||
{
|
||||
|
||||
uint2 res0 = mulhilo32(kPhiloxSA, ctr.x);
|
||||
|
||||
Reference in New Issue
Block a user