Add compute data type alias for RoPE

This commit is contained in:
PoYen, Chen
2024-07-15 00:05:33 +00:00
parent b0925bb7f6
commit f6850aef29
2 changed files with 8 additions and 3 deletions

View File

@@ -420,6 +420,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
using QDataType = typename TypeConfig::QDataType;
using KDataType = typename TypeConfig::KDataType;
using VDataType = typename TypeConfig::VDataType;
using RoPEComputeDataType = typename TypeConfig::RoPEComputeDataType;
using BiasDataType = typename TypeConfig::BiasDataType;
using RandValOutputDataType = typename TypeConfig::RandValOutputDataType;
using LSEDataType = typename TypeConfig::LSEDataType;
@@ -1041,12 +1042,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
{
decltype(q_host_ref) q_host_ref_ro(q_host_ref.get_lengths());
ck_tile::reference_batched_rotary_position_embedding<float>(
ck_tile::reference_batched_rotary_position_embedding<RoPEComputeDataType>(
q_host_ref, rotary_cos_host, rotary_sin_host, is_rotary_interleaved, q_host_ref_ro);
q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); });
}
if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[0] / nr, i[1] + key_offset, i[2]); });
else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[1] + key_offset, i[0] / nr, i[2]); });
@@ -1064,7 +1065,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
{
knew_host_ref_ro.emplace(knew_host_ref.get_lengths());
ck_tile::reference_batched_rotary_position_embedding<float>(
ck_tile::reference_batched_rotary_position_embedding<RoPEComputeDataType>(
knew_host_ref,
rotary_cos_host,
rotary_sin_host,

View File

@@ -20,6 +20,7 @@ struct FmhaFwdTypeConfig<ck_tile::half_t>
using QDataType = ck_tile::half_t;
using KDataType = ck_tile::half_t;
using VDataType = ck_tile::half_t;
using RoPEComputeDataType = float;
using BiasDataType = ck_tile::half_t;
using RandValOutputDataType = uint8_t;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
@@ -36,6 +37,7 @@ struct FmhaFwdTypeConfig<ck_tile::bf16_t>
using QDataType = ck_tile::bf16_t;
using KDataType = ck_tile::bf16_t;
using VDataType = ck_tile::bf16_t;
using RoPEComputeDataType = float;
using BiasDataType = ck_tile::bf16_t;
using RandValOutputDataType = uint8_t;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
@@ -52,6 +54,7 @@ struct FmhaFwdTypeConfig<ck_tile::fp8_t>
using QDataType = ck_tile::fp8_t;
using KDataType = ck_tile::fp8_t;
using VDataType = ck_tile::fp8_t;
using RoPEComputeDataType = float;
using BiasDataType = float;
using RandValOutputDataType = uint8_t;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
@@ -68,6 +71,7 @@ struct FmhaFwdTypeConfig<ck_tile::bf8_t>
using QDataType = ck_tile::bf8_t;
using KDataType = ck_tile::bf8_t;
using VDataType = ck_tile::bf8_t;
using RoPEComputeDataType = float;
using BiasDataType = ck_tile::bf8_t;
using RandValOutputDataType = uint8_t;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))