mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Add compute data type alias for RoPE
This commit is contained in:
@@ -420,6 +420,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
using QDataType = typename TypeConfig::QDataType;
|
||||
using KDataType = typename TypeConfig::KDataType;
|
||||
using VDataType = typename TypeConfig::VDataType;
|
||||
using RoPEComputeDataType = typename TypeConfig::RoPEComputeDataType;
|
||||
using BiasDataType = typename TypeConfig::BiasDataType;
|
||||
using RandValOutputDataType = typename TypeConfig::RandValOutputDataType;
|
||||
using LSEDataType = typename TypeConfig::LSEDataType;
|
||||
@@ -1041,12 +1042,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
decltype(q_host_ref) q_host_ref_ro(q_host_ref.get_lengths());
|
||||
|
||||
ck_tile::reference_batched_rotary_position_embedding<float>(
|
||||
ck_tile::reference_batched_rotary_position_embedding<RoPEComputeDataType>(
|
||||
q_host_ref, rotary_cos_host, rotary_sin_host, is_rotary_interleaved, q_host_ref_ro);
|
||||
|
||||
q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); });
|
||||
}
|
||||
|
||||
|
||||
if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[0] / nr, i[1] + key_offset, i[2]); });
|
||||
else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[1] + key_offset, i[0] / nr, i[2]); });
|
||||
|
||||
@@ -1064,7 +1065,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
knew_host_ref_ro.emplace(knew_host_ref.get_lengths());
|
||||
|
||||
ck_tile::reference_batched_rotary_position_embedding<float>(
|
||||
ck_tile::reference_batched_rotary_position_embedding<RoPEComputeDataType>(
|
||||
knew_host_ref,
|
||||
rotary_cos_host,
|
||||
rotary_sin_host,
|
||||
|
||||
@@ -20,6 +20,7 @@ struct FmhaFwdTypeConfig<ck_tile::half_t>
|
||||
using QDataType = ck_tile::half_t;
|
||||
using KDataType = ck_tile::half_t;
|
||||
using VDataType = ck_tile::half_t;
|
||||
using RoPEComputeDataType = float;
|
||||
using BiasDataType = ck_tile::half_t;
|
||||
using RandValOutputDataType = uint8_t;
|
||||
using LSEDataType = float; // data type for LSE, the log-sum-exp: L_j = max_j + log(l_j)
|
||||
@@ -36,6 +37,7 @@ struct FmhaFwdTypeConfig<ck_tile::bf16_t>
|
||||
using QDataType = ck_tile::bf16_t;
|
||||
using KDataType = ck_tile::bf16_t;
|
||||
using VDataType = ck_tile::bf16_t;
|
||||
using RoPEComputeDataType = float;
|
||||
using BiasDataType = ck_tile::bf16_t;
|
||||
using RandValOutputDataType = uint8_t;
|
||||
using LSEDataType = float; // data type for LSE, the log-sum-exp: L_j = max_j + log(l_j)
|
||||
@@ -52,6 +54,7 @@ struct FmhaFwdTypeConfig<ck_tile::fp8_t>
|
||||
using QDataType = ck_tile::fp8_t;
|
||||
using KDataType = ck_tile::fp8_t;
|
||||
using VDataType = ck_tile::fp8_t;
|
||||
using RoPEComputeDataType = float;
|
||||
using BiasDataType = float;
|
||||
using RandValOutputDataType = uint8_t;
|
||||
using LSEDataType = float; // data type for LSE, the log-sum-exp: L_j = max_j + log(l_j)
|
||||
@@ -68,6 +71,7 @@ struct FmhaFwdTypeConfig<ck_tile::bf8_t>
|
||||
using QDataType = ck_tile::bf8_t;
|
||||
using KDataType = ck_tile::bf8_t;
|
||||
using VDataType = ck_tile::bf8_t;
|
||||
using RoPEComputeDataType = float;
|
||||
using BiasDataType = ck_tile::bf8_t;
|
||||
using RandValOutputDataType = uint8_t;
|
||||
using LSEDataType = float; // data type for LSE, the log-sum-exp: L_j = max_j + log(l_j)
|
||||
|
||||
Reference in New Issue
Block a user