From f6850aef2990c1b3964bb1eb3358d07138fb7749 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Mon, 15 Jul 2024 00:05:33 +0000 Subject: [PATCH] Add compute data type alias for RoPE --- example/ck_tile/01_fmha/fmha_fwd.cpp | 7 ++++--- example/ck_tile/01_fmha/fmha_fwd.hpp | 4 ++++ 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index ddae8d162f..3ab1862e35 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -420,6 +420,7 @@ bool run(const ck_tile::ArgParser& arg_parser) using QDataType = typename TypeConfig::QDataType; using KDataType = typename TypeConfig::KDataType; using VDataType = typename TypeConfig::VDataType; + using RoPEComputeDataType = typename TypeConfig::RoPEComputeDataType; using BiasDataType = typename TypeConfig::BiasDataType; using RandValOutputDataType = typename TypeConfig::RandValOutputDataType; using LSEDataType = typename TypeConfig::LSEDataType; @@ -1041,12 +1042,12 @@ bool run(const ck_tile::ArgParser& arg_parser) { decltype(q_host_ref) q_host_ref_ro(q_host_ref.get_lengths()); - ck_tile::reference_batched_rotary_position_embedding( + ck_tile::reference_batched_rotary_position_embedding( q_host_ref, rotary_cos_host, rotary_sin_host, is_rotary_interleaved, q_host_ref_ro); q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); }); } - + if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[0] / nr, i[1] + key_offset, i[2]); }); else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[1] + key_offset, i[0] / nr, i[2]); }); @@ -1064,7 +1065,7 @@ bool run(const ck_tile::ArgParser& arg_parser) { knew_host_ref_ro.emplace(knew_host_ref.get_lengths()); - ck_tile::reference_batched_rotary_position_embedding( + ck_tile::reference_batched_rotary_position_embedding( knew_host_ref, rotary_cos_host, rotary_sin_host, diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index ba9cd06833..ec2aa85da2 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -20,6 +20,7 @@ struct FmhaFwdTypeConfig using QDataType = ck_tile::half_t; using KDataType = ck_tile::half_t; using VDataType = ck_tile::half_t; + using RoPEComputeDataType = float; using BiasDataType = ck_tile::half_t; using RandValOutputDataType = uint8_t; using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) @@ -36,6 +37,7 @@ struct FmhaFwdTypeConfig using QDataType = ck_tile::bf16_t; using KDataType = ck_tile::bf16_t; using VDataType = ck_tile::bf16_t; + using RoPEComputeDataType = float; using BiasDataType = ck_tile::bf16_t; using RandValOutputDataType = uint8_t; using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) @@ -52,6 +54,7 @@ struct FmhaFwdTypeConfig using QDataType = ck_tile::fp8_t; using KDataType = ck_tile::fp8_t; using VDataType = ck_tile::fp8_t; + using RoPEComputeDataType = float; using BiasDataType = float; using RandValOutputDataType = uint8_t; using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) @@ -68,6 +71,7 @@ struct FmhaFwdTypeConfig using QDataType = ck_tile::bf8_t; using KDataType = ck_tile::bf8_t; using VDataType = ck_tile::bf8_t; + using RoPEComputeDataType = float; using BiasDataType = ck_tile::bf8_t; using RandValOutputDataType = uint8_t; using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))