mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
fix xx
This commit is contained in:
@@ -72,7 +72,7 @@ auto get_elimit(int /*init_method*/)
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::bhalf_t>(int init_method)
|
||||
auto get_elimit<ck_tile::bf16_t>(int init_method)
|
||||
{
|
||||
if(init_method == 0)
|
||||
{
|
||||
@@ -510,7 +510,7 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run<ck_tile::bhalf_t>(arg_parser) ? 0 : -2;
|
||||
return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/fmha.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "mask.hpp"
|
||||
|
||||
template <typename DataType>
|
||||
@@ -29,18 +28,18 @@ struct FmhaFwdTypeConfig<ck_tile::half_t>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FmhaFwdTypeConfig<ck_tile::bhalf_t>
|
||||
struct FmhaFwdTypeConfig<ck_tile::bf16_t>
|
||||
{
|
||||
using QDataType = ck_tile::bhalf_t;
|
||||
using KDataType = ck_tile::bhalf_t;
|
||||
using VDataType = ck_tile::bhalf_t;
|
||||
using BiasDataType = ck_tile::bhalf_t;
|
||||
using QDataType = ck_tile::bf16_t;
|
||||
using KDataType = ck_tile::bf16_t;
|
||||
using VDataType = ck_tile::bf16_t;
|
||||
using BiasDataType = ck_tile::bf16_t;
|
||||
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using PDataType = ck_tile::bhalf_t; // data type for A matrix of second gemm
|
||||
using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using ODataType = ck_tile::bhalf_t;
|
||||
using ODataType = ck_tile::bf16_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
|
||||
@@ -11,7 +11,7 @@ import copy
|
||||
|
||||
DTYPE_MAP = {
|
||||
"fp16": "ck_tile::half_t",
|
||||
"bf16": "ck_tile::bhalf_t",
|
||||
"bf16": "ck_tile::bf16_t",
|
||||
"fp8" : "ck_tile::fp8_t"
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user