mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
* chore(copyright): update copyright header for codegen directory * chore(copyright): update copyright header for example directory
115 lines
3.0 KiB
C++
115 lines
3.0 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#pragma once
|
|
|
|
#include <ostream>
|
|
#include <string>
|
|
#include "ck_tile/core.hpp"
|
|
#include "ck_tile/ops/fmha.hpp"
|
|
|
|
// keep sync with BlockAttentionBiasEnum
|
|
enum class bias_enum
|
|
{
|
|
no_bias = 0,
|
|
elementwise_bias = 1,
|
|
alibi = 2,
|
|
};
|
|
|
|
struct bias_info
|
|
{
|
|
bias_enum type;
|
|
/*
|
|
* simple dispatch logic
|
|
*
|
|
* if type == elementwise_bias:
|
|
* if rank_info == 0:
|
|
* bias is 1*1*s*s
|
|
* elif rank_info == 1:
|
|
* bias is 1*h*s*s
|
|
* elif rank_info == 2:
|
|
* bias is b*h*s*s
|
|
*
|
|
* elif type == alibi:
|
|
* if rank_info == 0:
|
|
* alibi in 1*h
|
|
* elif rank_info == 1:
|
|
* alibi in b*h
|
|
*/
|
|
int rank_info;
|
|
|
|
void serialize(std::ostream& os) const
|
|
{
|
|
if(type == bias_enum::no_bias)
|
|
os << "n";
|
|
else if(type == bias_enum::elementwise_bias)
|
|
{
|
|
os << "e";
|
|
if(rank_info != 0)
|
|
{
|
|
os << "[" << rank_info << "]";
|
|
}
|
|
}
|
|
else if(type == bias_enum::alibi)
|
|
{
|
|
os << "alibi";
|
|
if(rank_info != 0)
|
|
{
|
|
os << "[" << rank_info << "]";
|
|
}
|
|
}
|
|
}
|
|
|
|
static bias_info decode(std::string str)
|
|
{
|
|
bias_info info{bias_enum::no_bias, 0};
|
|
auto found_0 = str.find(':');
|
|
if(found_0 != std::string::npos)
|
|
{
|
|
std::string t = str.substr(0, found_0);
|
|
std::string v = str.substr(found_0 + 1);
|
|
if(t == "e" || t == "elementwise")
|
|
{
|
|
info.type = bias_enum::elementwise_bias;
|
|
info.rank_info = std::stoi(v);
|
|
if(info.rank_info < 0 || info.rank_info > 2)
|
|
throw std::invalid_argument("invalid bias rank: " + str);
|
|
}
|
|
else if(t == "a" || t == "alibi")
|
|
{
|
|
info.type = bias_enum::alibi;
|
|
info.rank_info = std::stoi(v);
|
|
if(info.rank_info < 0 || info.rank_info > 1)
|
|
throw std::invalid_argument("invalid bias rank: " + str);
|
|
}
|
|
else
|
|
{
|
|
throw std::invalid_argument("invalid bias value: " + str);
|
|
}
|
|
}
|
|
else if(str == "0" || str == "n")
|
|
{
|
|
info.type = bias_enum::no_bias;
|
|
}
|
|
else if(str == "1" || str == "e" || str == "elementwise")
|
|
{
|
|
info.type = bias_enum::elementwise_bias;
|
|
}
|
|
else if(str == "2" || str == "a" || str == "alibi")
|
|
{
|
|
info.type = bias_enum::alibi;
|
|
}
|
|
else
|
|
{
|
|
throw std::invalid_argument("invalid bias value: " + str);
|
|
}
|
|
return info;
|
|
}
|
|
|
|
friend std::ostream& operator<<(std::ostream& os, const bias_info& bi)
|
|
{
|
|
bi.serialize(os);
|
|
return os;
|
|
}
|
|
};
|