From 58f61716b5ef5b5ee662cbad3306f1ccb96adca3 Mon Sep 17 00:00:00 2001 From: danyao12 Date: Wed, 29 May 2024 16:20:34 +0800 Subject: [PATCH] CK_TILE_HOST_DEVICE in philox --- include/ck_tile/core/utility/philox_rand.hpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/include/ck_tile/core/utility/philox_rand.hpp b/include/ck_tile/core/utility/philox_rand.hpp index d68381e369..c49f44ae48 100644 --- a/include/ck_tile/core/utility/philox_rand.hpp +++ b/include/ck_tile/core/utility/philox_rand.hpp @@ -3,13 +3,15 @@ #pragma once +#include "ck_tile/core/config.hpp" + namespace ck_tile { // Reference: https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/philox.cuh class philox { public: - __host__ __device__ inline philox(unsigned long long seed_, unsigned long long offset_) + CK_TILE_HOST_DEVICE philox(unsigned long long seed_, unsigned long long offset_) : seed(reinterpret_cast<const uint2&>(seed_)) { @@ -17,7 +19,7 @@ class philox tmp->x = offset_; } - __host__ __device__ inline uint4 get_philox_4x32(const unsigned long long subsequence) const + CK_TILE_HOST_DEVICE uint4 get_philox_4x32(const unsigned long long subsequence) const { uint4 counter_ = counter; @@ -37,7 +39,7 @@ class philox return output; } - __host__ __device__ void get_random_16x8(uint8_t* out, + CK_TILE_HOST_DEVICE void get_random_16x8(uint8_t* out, const unsigned long long subsequence) const { uint4 tmp_ph; @@ -60,7 +62,7 @@ class philox uint4 counter; const uint2 seed; - __host__ __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) const + CK_TILE_HOST_DEVICE uint2 mulhilo32(const unsigned int a, const unsigned int b) const { uint2* res; unsigned long long tmp; @@ -69,7 +71,7 @@ class philox return *res; } - __host__ __device__ inline uint4 philox_single_round(const uint4 ctr, const uint2 key) const + CK_TILE_HOST_DEVICE uint4 philox_single_round(const uint4 ctr, const uint2 key) const { uint2 res0 = mulhilo32(kPhiloxSA, ctr.x);