mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 11:30:02 +00:00
[CK_TILE] support alibi (#1269)
* add alibi support
* fix code
* update code based on comment
* Support more hdim
* fix fp8 bias
* support seqlen_k=0 case
* remove unused printf
* fix format
---------
Co-authored-by: rocking <ChunYu.Lai@amd.com>
[ROCm/composable_kernel commit: 851c3ed157]
This commit is contained in:
@@ -17,7 +17,7 @@ add_custom_command(
|
||||
set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd")
|
||||
# not using add_example_executable() to add this target, since we don't want this to have
|
||||
# to be included in "make all/install/check"
|
||||
message("adding tile_example ${EXAMPLE_NAME}")
|
||||
message("adding example ${EXAMPLE_FMHA_FWD}")
|
||||
add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL fmha_fwd.cpp)
|
||||
target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS})
|
||||
|
||||
@@ -30,27 +30,29 @@ args:
|
||||
-mode kernel mode. 0:batch, 1:group (default:0)
|
||||
-b batch size (default:2)
|
||||
-h num of head, for q (default:8)
|
||||
-h_k num of head, for k/v, 0 means equal to h (default:0)
|
||||
-h_k num of head, for k/v, -1 means equal to h (default:-1)
|
||||
if not equal to h, then this is GQA/MQA case
|
||||
-s seqlen_q. if group-mode, means the average value of seqlen_q (default:3328)
|
||||
total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary
|
||||
-s_k seqlen_k, 0 means equal to s (default:0)
|
||||
-s_k seqlen_k, -1 means equal to s (default:-1)
|
||||
-d head dim for q, k (default:128)
|
||||
-d_v head dim for v, 0 means equal to d (default:0)
|
||||
-d_v head dim for v, -1 means equal to d (default:-1)
|
||||
-scale_s scale factor of S. 0 means equal to 1/sqrt(hdim). (default:0)
|
||||
note when squant=1, this value will be modified by range_q/k
|
||||
-range_q per-tensor quantization range of q. used if squant=1. (default:2)
|
||||
-range_k per-tensor quantization range of k. used if squant=1. (default:2)
|
||||
-range_v per-tensor quantization range of v. used if squant=1. (default:2)
|
||||
-range_q per-tensor quantization range of q. used if squant=1. (default:16)
|
||||
-range_k per-tensor quantization range of k. used if squant=1. (default:16)
|
||||
-range_v per-tensor quantization range of v. used if squant=1. (default:16)
|
||||
-range_p per-tensor quantization range of p [e^(s-m)]. used if squant=1. (default:1)
|
||||
-range_o per-tensor quantization range of o (p*v). used if squant=1. (default:2)
|
||||
-range_o per-tensor quantization range of o (p*v). used if squant=1. (default:16)
|
||||
-squant if using static quantization fusion or not. 0: original flow(not prefered) (default:0)
|
||||
1: apply scale_p and scale_o with respect to P and O. calculate scale_s, scale_p,
|
||||
scale_o according to range_q, range_k, range_v, range_p, range_o
|
||||
-iperm permute input (default:1)
|
||||
if true, will be b*h*s*d, else b*s*h*d
|
||||
-operm permute output (default:1)
|
||||
-bias add bias or not (default:0)
|
||||
-bias n or 0, no bias (default:n)
|
||||
e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s
|
||||
a(libi) or 2, alibi with 1*h. a:1, b*h
|
||||
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
|
||||
-mask 0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b') (default:0)
|
||||
't', top-left causal mask, 'b', bottom-r causal mask
|
||||
@@ -59,11 +61,11 @@ args:
|
||||
'xt:window_size', xformer style masking from top-left, window_size negative is causal, positive is swa
|
||||
'xb:window_size', xformer style masking from bottom-r, window_size negative is causal, positive is swa
|
||||
'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for now)
|
||||
|
||||
-vlayout r for row-major(seqlen*hdim), c for col-major(hdim*seqlen) (default:r)
|
||||
-lse 0 not store lse, 1 store lse (default:0)
|
||||
-kname if set to 1 will print kernel name (default:0)
|
||||
-init init method. 0:random int, 1:random float, 2:trig float, 3:quantization (default:1)
|
||||
-seed random seed used for initializing input tensors. 0 for non-deterministic seed (default:11939)
|
||||
```
|
||||
Example: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case.
|
||||
|
||||
@@ -85,6 +87,9 @@ If you look at the kernel argument inside `fmha_fwd_kernel.hpp`, we support prov
|
||||
### attention bias
|
||||
Attention bias is supported with the layout of `1*1*s*s`(similiar to input/output, different layout can be supported by changing the stride value for bias, or even extend to `b*h*s*s`) and bias value in float number.
|
||||
|
||||
### alibi
|
||||
alibi is supported
|
||||
|
||||
### lse
|
||||
For training kernels, "log sum exp" need to store out in forward and used in backward. We support this by setting `-lse=1`
|
||||
|
||||
|
||||
100
example/ck_tile/01_fmha/bias.hpp
Normal file
100
example/ck_tile/01_fmha/bias.hpp
Normal file
@@ -0,0 +1,100 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha.hpp"
|
||||
|
||||
// keep sync with BlockAttentionBiasEnum
|
||||
enum class bias_enum
|
||||
{
|
||||
no_bias = 0,
|
||||
elementwise_bias = 1,
|
||||
alibi = 2,
|
||||
};
|
||||
|
||||
struct bias_info
|
||||
{
|
||||
bias_enum type;
|
||||
/*
|
||||
* simple dispatch logic
|
||||
*
|
||||
* if type == elementwise_bias:
|
||||
* if rank_info == 0:
|
||||
* bias is 1*1*s*s
|
||||
* elif rank_info == 1:
|
||||
* bias is 1*h*s*s
|
||||
* elif rank_info == 2:
|
||||
* bias is b*h*s*s
|
||||
*
|
||||
* elif type == alibi:
|
||||
* if rank_info == 0:
|
||||
* alibi in 1*h
|
||||
* elif rank_info == 1:
|
||||
* alibi in b*h
|
||||
*/
|
||||
int rank_info;
|
||||
|
||||
void serialize(std::ostream& os) const
|
||||
{
|
||||
if(type == bias_enum::no_bias)
|
||||
os << "n";
|
||||
else if(type == bias_enum::elementwise_bias)
|
||||
{
|
||||
os << "e";
|
||||
if(rank_info != 0)
|
||||
{
|
||||
os << "[" << rank_info << "]";
|
||||
}
|
||||
}
|
||||
else if(type == bias_enum::alibi)
|
||||
{
|
||||
os << "alibi";
|
||||
if(rank_info != 0)
|
||||
{
|
||||
os << "[" << rank_info << "]";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static bias_info decode(std::string str)
|
||||
{
|
||||
bias_info info{bias_enum::no_bias, 0};
|
||||
if(str == "0" || str == "n")
|
||||
{
|
||||
info.type = bias_enum::no_bias;
|
||||
}
|
||||
else if(str.compare(0, 1, "1") == 0 || str.compare(0, 1, "e") == 0 ||
|
||||
str.compare(0, 11, "elementwise") == 0)
|
||||
{
|
||||
info.type = bias_enum::elementwise_bias;
|
||||
auto found_0 = str.find(':');
|
||||
if(found_0 != std::string::npos)
|
||||
{
|
||||
std::string e = str.substr(found_0 + 1);
|
||||
info.rank_info = atoi(e.c_str());
|
||||
}
|
||||
}
|
||||
else if(str.compare(0, 1, "2") == 0 || str.compare(0, 1, "a") == 0 ||
|
||||
str.compare(0, 5, "alibi") == 0)
|
||||
{
|
||||
info.type = bias_enum::alibi;
|
||||
auto found_0 = str.find(':');
|
||||
if(found_0 != std::string::npos)
|
||||
{
|
||||
std::string e = str.substr(found_0 + 1);
|
||||
info.rank_info = atoi(e.c_str());
|
||||
}
|
||||
}
|
||||
return info;
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const bias_info& bi)
|
||||
{
|
||||
bi.serialize(os);
|
||||
return os;
|
||||
}
|
||||
};
|
||||
@@ -41,16 +41,16 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("b", "2", "batch size")
|
||||
.insert("h", "8", "num of head, for q")
|
||||
.insert("h_k",
|
||||
"0",
|
||||
"num of head, for k/v, 0 means equal to h\n"
|
||||
"-1",
|
||||
"num of head, for k/v, -1 means equal to h\n"
|
||||
"if not equal to h, then this is GQA/MQA case")
|
||||
.insert("s",
|
||||
"3328",
|
||||
"seqlen_q. if group-mode, means the average value of seqlen_q\n"
|
||||
"total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary")
|
||||
.insert("s_k", "0", "seqlen_k, 0 means equal to s")
|
||||
.insert("s_k", "-1", "seqlen_k, -1 means equal to s")
|
||||
.insert("d", "128", "head dim for q, k")
|
||||
.insert("d_v", "0", "head dim for v, 0 means equal to d")
|
||||
.insert("d_v", "-1", "head dim for v, -1 means equal to d")
|
||||
.insert("scale_s",
|
||||
"0",
|
||||
"scale factor of S. 0 means equal to 1/sqrt(hdim).\n"
|
||||
@@ -71,7 +71,11 @@ auto create_args(int argc, char* argv[])
|
||||
"permute input\n"
|
||||
"if true, will be b*h*s*d, else b*s*h*d")
|
||||
.insert("operm", "1", "permute output")
|
||||
.insert("bias", "0", "add bias or not")
|
||||
.insert("bias",
|
||||
"n",
|
||||
"n or 0, no bias\n"
|
||||
"e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n"
|
||||
"a(libi) or 2, alibi with 1*h. a:1, b*h")
|
||||
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
|
||||
.insert("mask",
|
||||
"0",
|
||||
@@ -153,7 +157,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::index_t batch = arg_parser.get_int("b");
|
||||
ck_tile::index_t nhead = arg_parser.get_int("h");
|
||||
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
|
||||
if(nhead_k == 0)
|
||||
if(nhead_k < 0)
|
||||
nhead_k = nhead;
|
||||
|
||||
if(nhead % nhead_k != 0)
|
||||
@@ -164,11 +168,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
ck_tile::index_t seqlen_q = arg_parser.get_int("s");
|
||||
ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
|
||||
if(seqlen_k == 0)
|
||||
if(seqlen_k < 0)
|
||||
seqlen_k = seqlen_q;
|
||||
ck_tile::index_t hdim_q = arg_parser.get_int("d");
|
||||
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
|
||||
if(hdim_v == 0)
|
||||
if(hdim_v < 0)
|
||||
hdim_v = hdim_q;
|
||||
|
||||
bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim
|
||||
@@ -208,9 +212,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
|
||||
std::string vlayout = arg_parser.get_str("vlayout");
|
||||
bool use_bias = arg_parser.get_bool("bias");
|
||||
bool lse = arg_parser.get_bool("lse");
|
||||
|
||||
bias_info bias = bias_info::decode(arg_parser.get_str("bias"));
|
||||
mask_info mask = mask_info::decode(arg_parser.get_str("mask"), seqlen_q, seqlen_k);
|
||||
|
||||
int init_method = arg_parser.get_int("init");
|
||||
@@ -295,12 +299,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::HostTensor<VDataType> v_host(
|
||||
is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v)
|
||||
: get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k));
|
||||
// use bias shape = [1, 1, shape_seqlen_q, shape_seqlen_k]. if use_bias=false, the bias_host
|
||||
// will not be used for verification at all (but will be copied to device anyway).
|
||||
|
||||
ck_tile::HostTensor<BiasDataType> bias_host(
|
||||
use_bias
|
||||
bias.type == bias_enum::elementwise_bias
|
||||
? get_lengths(i_perm, 1, 1, shape_seqlen_q, shape_seqlen_k)
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
|
||||
|
||||
ck_tile::HostTensor<SaccDataType> alibi_slope_host(
|
||||
bias.type == bias_enum::alibi
|
||||
? (bias.rank_info == 0 ? std::array<ck_tile::index_t, 2>{1, nhead}
|
||||
: std::array<ck_tile::index_t, 2>{batch, nhead})
|
||||
: std::array<ck_tile::index_t, 2>{1, 1});
|
||||
|
||||
// self define lse data layout as [shape_batch, nhead, shape_seqlen_q]
|
||||
ck_tile::HostTensor<LSEDataType> lse_host(
|
||||
lse ? std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q}
|
||||
@@ -341,6 +351,24 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
// Assume bias is in [-1.f, 1.f] in original fp32
|
||||
ck_tile::FillUniformDistribution<BiasDataType>{-qscale_bias, qscale_bias, seed}(bias_host);
|
||||
}
|
||||
if(bias.type == bias_enum::alibi)
|
||||
{
|
||||
auto slopes = ck_tile::get_alibi_slopes<SaccDataType>(nhead);
|
||||
assert(slopes.size() == nhead);
|
||||
if(bias.rank_info == 0)
|
||||
{
|
||||
// alibi in 1*h
|
||||
std::copy(slopes.begin(), slopes.end(), alibi_slope_host.begin());
|
||||
}
|
||||
else
|
||||
{
|
||||
// alibi in b*h
|
||||
for(auto i_b = 0; i_b < batch; i_b++)
|
||||
{
|
||||
std::copy(slopes.begin(), slopes.end(), alibi_slope_host.begin() + i_b * nhead);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
|
||||
@@ -350,6 +378,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
|
||||
|
||||
q_buf.ToDevice(q_host.data());
|
||||
k_buf.ToDevice(k_host.data());
|
||||
@@ -357,6 +386,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
bias_buf.ToDevice(bias_host.data());
|
||||
seqstart_q.ToDevice(seqstart_q_host.data());
|
||||
seqstart_k.ToDevice(seqstart_k_host.data());
|
||||
alibi_slope_buf.ToDevice(alibi_slope_host.data());
|
||||
|
||||
// clang-format off
|
||||
auto layout_str = [&](bool permute){
|
||||
@@ -372,9 +402,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch
|
||||
<< ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k
|
||||
<< ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s
|
||||
<< ", bias:" << use_bias << ", lse:" << lse << ", squant:" << squant
|
||||
<< ", mask:" << mask << ", v:" << vlayout << std::flush;
|
||||
<< ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias
|
||||
<< ", lse:" << lse << ", squant:" << squant << ", mask:" << mask << ", v:" << vlayout
|
||||
<< std::flush;
|
||||
|
||||
auto fmha_traits = fmha_fwd_traits{hdim_q,
|
||||
hdim_v,
|
||||
@@ -382,7 +412,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
mode == mode_enum::group,
|
||||
is_v_rowmajor,
|
||||
mask.type,
|
||||
use_bias,
|
||||
bias.type,
|
||||
lse,
|
||||
squant};
|
||||
|
||||
@@ -441,7 +471,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
return fmha_fwd_args{q_buf.GetDeviceBuffer(),
|
||||
k_buf.GetDeviceBuffer(),
|
||||
v_buf.GetDeviceBuffer(),
|
||||
bias_buf.GetDeviceBuffer(),
|
||||
bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer()
|
||||
: bias_buf.GetDeviceBuffer(),
|
||||
lse_buf.GetDeviceBuffer(),
|
||||
o_buf.GetDeviceBuffer(),
|
||||
seqstart_q.GetDeviceBuffer(),
|
||||
@@ -461,7 +492,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_v,
|
||||
stride_bias,
|
||||
bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead)
|
||||
: stride_bias,
|
||||
stride_o,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
@@ -564,8 +596,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::identity{},
|
||||
ck_tile::scales(scale_s));
|
||||
|
||||
if(use_bias)
|
||||
if(bias.type == bias_enum::elementwise_bias)
|
||||
{
|
||||
// elementwise bias
|
||||
ck_tile::HostTensor<BiasDataType> bias_host_ref({1, real_seqlen_q, real_seqlen_k});
|
||||
// clang-format off
|
||||
if(i_perm)
|
||||
@@ -582,6 +615,51 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
SMPLComputeDataType>(
|
||||
s_host_ref, bias_host_ref, s_host_ref);
|
||||
}
|
||||
else if(bias.type == bias_enum::alibi)
|
||||
{
|
||||
// alibi construct elementwise bias to verify
|
||||
auto alibi_host = [&]() {
|
||||
if(mask.type != mask_enum::no_mask)
|
||||
{
|
||||
return ck_tile::make_alibi_from_lr_mask<SaccDataType, true>(
|
||||
0,
|
||||
mask.left,
|
||||
mask.right,
|
||||
real_seqlen_q,
|
||||
real_seqlen_k,
|
||||
static_cast<ck_tile::GenericAttentionMaskEnum>(mask.type));
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::Alibi<SaccDataType, true>{
|
||||
0, real_seqlen_q, real_seqlen_k, ck_tile::AlibiMode::VERTICAL};
|
||||
}
|
||||
}();
|
||||
|
||||
ck_tile::HostTensor<SaccDataType> alibi_bias_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k});
|
||||
auto i_b_slope = bias.rank_info == 0 ? 0 : wb;
|
||||
for(auto i_h = 0; i_h < nhead; i_h++)
|
||||
{
|
||||
SaccDataType current_slope = alibi_slope_host(i_b_slope, i_h);
|
||||
alibi_host.slope = current_slope;
|
||||
for(auto i_r = 0; i_r < real_seqlen_q; i_r++)
|
||||
{
|
||||
for(auto i_c = 0; i_c < real_seqlen_k; i_c++)
|
||||
{
|
||||
SaccDataType pixel = 0;
|
||||
alibi_host.update(pixel, i_r, i_c);
|
||||
alibi_bias_host_ref(i_h, i_r, i_c) = pixel;
|
||||
}
|
||||
}
|
||||
}
|
||||
// [nhead, real_seqlen_q, real_seqlen_k]
|
||||
ck_tile::reference_batched_elementwise<SMPLComputeDataType,
|
||||
SaccDataType,
|
||||
SMPLComputeDataType,
|
||||
SMPLComputeDataType>(
|
||||
s_host_ref, alibi_bias_host_ref, s_host_ref);
|
||||
}
|
||||
|
||||
if(mask.type == mask_enum::no_mask)
|
||||
{
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "ck_tile/ops/fmha.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "mask.hpp"
|
||||
#include "bias.hpp"
|
||||
#include <type_traits>
|
||||
|
||||
template <typename DataType>
|
||||
@@ -86,7 +87,7 @@ struct fmha_fwd_args
|
||||
const void* q_ptr;
|
||||
const void* k_ptr;
|
||||
const void* v_ptr;
|
||||
const void* bias_ptr;
|
||||
const void* bias_ptr; // bias or alibi_slope pointer
|
||||
void* lse_ptr;
|
||||
void* o_ptr;
|
||||
const void* seqstart_q_ptr;
|
||||
@@ -106,7 +107,7 @@ struct fmha_fwd_args
|
||||
ck_tile::index_t stride_q;
|
||||
ck_tile::index_t stride_k;
|
||||
ck_tile::index_t stride_v;
|
||||
ck_tile::index_t stride_bias;
|
||||
ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0
|
||||
ck_tile::index_t stride_o;
|
||||
ck_tile::index_t nhead_stride_q;
|
||||
ck_tile::index_t nhead_stride_k;
|
||||
@@ -219,7 +220,7 @@ template <ck_tile::index_t HDim_,
|
||||
bool kIsVLayoutRowMajor_,
|
||||
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
|
||||
typename FmhaMask_,
|
||||
bool kHasBias_,
|
||||
ck_tile::BlockAttentionBiasEnum BiasEnum_,
|
||||
bool kStoreLse_,
|
||||
bool kDoFp8StaticQuant_,
|
||||
bool kPadS_,
|
||||
@@ -240,7 +241,7 @@ struct fmha_fwd_traits_
|
||||
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
|
||||
static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
|
||||
static constexpr bool kHasBias = kHasBias_;
|
||||
static constexpr auto BiasEnum = BiasEnum_;
|
||||
static constexpr bool kStoreLse = kStoreLse_;
|
||||
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
@@ -261,7 +262,7 @@ struct fmha_fwd_traits
|
||||
bool is_group_mode;
|
||||
bool is_v_rowmajor;
|
||||
mask_enum mask_type;
|
||||
bool has_bias;
|
||||
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
|
||||
bool has_lse;
|
||||
bool do_fp8_static_quant;
|
||||
// TODO: padding check is inside this api
|
||||
|
||||
@@ -40,6 +40,19 @@ MASK_MAP = {
|
||||
"generic" : "FmhaMasks::GenericMask"
|
||||
}
|
||||
|
||||
BIAS_MAP = {
|
||||
"no" : "ck_tile::BlockAttentionBiasEnum::NO_BIAS",
|
||||
"bias" : "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS",
|
||||
"alibi" : "ck_tile::BlockAttentionBiasEnum::ALIBI"
|
||||
}
|
||||
|
||||
# TODO: this is ugly
|
||||
BIAS_CHECK_MAP = {
|
||||
"no" : "bias_enum::no_bias",
|
||||
"bias" : "bias_enum::elementwise_bias",
|
||||
"alibi" : "bias_enum::alibi"
|
||||
}
|
||||
|
||||
MODE_MAP = {
|
||||
"batch" : "false",
|
||||
"group" : "true"
|
||||
@@ -173,7 +186,7 @@ MASK_SIMPLIFIED_CHECK_MAP = {
|
||||
"s_mask" : "t.mask_type != mask_enum::no_mask",
|
||||
}
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.has_bias == {F_bias}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) &&
|
||||
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
return fmha_fwd_<trait_>(s, a);
|
||||
@@ -213,7 +226,7 @@ class FmhaFwdApiTrait:
|
||||
bk0blen : int
|
||||
vlayout : str
|
||||
mask : str
|
||||
bias : str # true/false
|
||||
bias : str #
|
||||
lse : str #
|
||||
squant : str #
|
||||
spad : str
|
||||
@@ -241,8 +254,8 @@ class FmhaFwdApiTrait:
|
||||
def skcheck(self) -> str:
|
||||
if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
if self.skpad == 't' : return f'a.seqlen_k % {self.bn0} != 0'
|
||||
else : return f'a.seqlen_k % {self.bn0} == 0'
|
||||
if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
|
||||
else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
|
||||
elif self.pipeline_tag in ['qr', 'qr_fp8']:
|
||||
if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.seqlen_k % {self.bn0} == 0'
|
||||
@@ -297,7 +310,7 @@ class FmhaFwdPipeline:
|
||||
pn = pad_name()
|
||||
n = f'{self.tag}_v{self.F_vlayout[0]}'
|
||||
if pn != '' : n += f'_{pn}'
|
||||
if self.F_bias == 't' : n += '_bias'
|
||||
if self.F_bias != 'no' : n += f'_{self.F_bias}'
|
||||
if self.F_mask[0:2] == 's_':
|
||||
if self.F_mask == 's_mask': n += f'_mask'
|
||||
else:
|
||||
@@ -332,7 +345,8 @@ class FmhaFwdApiPool:
|
||||
if_k = 'if' if k == 0 else 'else if'
|
||||
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias=BOOL_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
|
||||
F_lse=BOOL_MAP[trait.lse],
|
||||
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
|
||||
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen,
|
||||
@@ -400,7 +414,7 @@ class FmhaFwdKernel:
|
||||
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
|
||||
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
|
||||
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
F_bias = BOOL_MAP[self.F_pipeline.F_bias],
|
||||
F_bias = BIAS_MAP[self.F_pipeline.F_bias],
|
||||
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
|
||||
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
|
||||
F_occupancy = self.F_tile.F_occupancy,
|
||||
@@ -454,7 +468,9 @@ def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
return {
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, -1)
|
||||
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 32, 32, 32, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, -1),
|
||||
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 32, -1)
|
||||
}
|
||||
else:
|
||||
return None
|
||||
@@ -472,7 +488,7 @@ def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFw
|
||||
squant = 't' if dtype == 'fp8' else 'f'
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for mask, bias, lse in itertools.product(get_mask_map(mask_impl).keys(), ["t", "f"], ["t", "f"]):
|
||||
for mask, bias, lse in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]):
|
||||
if hdim == 256:
|
||||
# if True:
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, squant, mask))
|
||||
@@ -490,7 +506,7 @@ def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFw
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, squant, mask)) # TODO: cover arbitraty hdim
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
# no need lse kernels
|
||||
for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), ["t", "f"]):
|
||||
for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', squant, mask))
|
||||
else:
|
||||
assert False
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -149,11 +149,9 @@ struct mask_info
|
||||
return tmp;
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const mask_info& mi);
|
||||
friend std::ostream& operator<<(std::ostream& os, const mask_info& mi)
|
||||
{
|
||||
mi.serialize(os);
|
||||
return os;
|
||||
}
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, const mask_info& mi)
|
||||
{
|
||||
mi.serialize(os);
|
||||
return os;
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ for perm in 0 1 ; do
|
||||
for vlayout in "r" "c" ; do
|
||||
for hdim in 32 64 128 256 ; do
|
||||
for lse in 0 1 ; do
|
||||
for bias in 0 1 ; do
|
||||
for bias in "n" "e" "a"; do
|
||||
|
||||
# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
|
||||
@@ -27,6 +27,7 @@ $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$b
|
||||
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
done
|
||||
done
|
||||
@@ -37,9 +38,11 @@ done
|
||||
done
|
||||
|
||||
for perm in 0 1 ; do
|
||||
for bias in 0 1 ; do
|
||||
for bias in "n" "e" "a" ; do
|
||||
for b in 1 2 ; do
|
||||
for hdim in 64 128 256 ; do
|
||||
$EXE -prec=fp8 -init=3 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=c -squant=1 -kname=$KNAME $COMMON_ARGS
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
Reference in New Issue
Block a user