This commit is contained in:
carlushuang
2024-03-03 23:48:31 +00:00
parent fbd25cea35
commit 112d521b09
66 changed files with 1720 additions and 1498 deletions

View File

@@ -72,7 +72,7 @@ auto get_elimit(int /*init_method*/)
}
template <>
auto get_elimit<ck_tile::bhalf_t>(int init_method)
auto get_elimit<ck_tile::bf16_t>(int init_method)
{
if(init_method == 0)
{
@@ -510,7 +510,7 @@ int main(int argc, char* argv[])
}
else if(data_type == "bf16")
{
return run<ck_tile::bhalf_t>(arg_parser) ? 0 : -2;
return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
}
else if(data_type == "fp8")
{

View File

@@ -7,7 +7,6 @@
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/fmha.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/common.hpp"
#include "mask.hpp"
template <typename DataType>
@@ -29,18 +28,18 @@ struct FmhaFwdTypeConfig<ck_tile::half_t>
};
template <>
struct FmhaFwdTypeConfig<ck_tile::bhalf_t>
struct FmhaFwdTypeConfig<ck_tile::bf16_t>
{
using QDataType = ck_tile::bhalf_t;
using KDataType = ck_tile::bhalf_t;
using VDataType = ck_tile::bhalf_t;
using BiasDataType = ck_tile::bhalf_t;
using QDataType = ck_tile::bf16_t;
using KDataType = ck_tile::bf16_t;
using VDataType = ck_tile::bf16_t;
using BiasDataType = ck_tile::bf16_t;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using PDataType = ck_tile::bhalf_t; // data type for A matrix of second gemm
using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm
using OaccDataType = float; // data type for second gemm accumulation
using ODataType = ck_tile::bhalf_t;
using ODataType = ck_tile::bf16_t;
};
template <>

View File

@@ -11,7 +11,7 @@ import copy
DTYPE_MAP = {
"fp16": "ck_tile::half_t",
"bf16": "ck_tile::bhalf_t",
"bf16": "ck_tile::bf16_t",
"fp8" : "ck_tile::fp8_t"
}