mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 19:40:04 +00:00
* add block scale parameters to kernel * add block scale to kernel * add smoke test * format * Revert "format" This reverts commit356c3c9706. * only format my code * format py * fix auto not allowd in function prototype * change instance tttt to ttff * fix structured binding issue * change s_acc elementwise op * async pipeline add block scale * add quantation P using shift exp2 * precompute (m - shift) once per row * change blk scale seqstrt ptr name * fix some name * fix for deduction guide * fix some comments * add P scale to qr_ksvs_pipeline * add comment to idx_identity * change the method of calculating descale block index * unify naming style: use block_scale_ as name prefix * unify naming style * update the CHANGELOG.md * Add FP8 block scale quantization support for FMHA forward kernel --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com> [ROCm/composable_kernel commit:dd0b4294af]
61 lines
1.4 KiB
C++
61 lines
1.4 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#pragma once
|
|
|
|
#include <ostream>
#include <stdexcept>
#include <string>

#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha.hpp"
|
|
|
|
// keep sync with BlockAttentionQuantScaleEnum
|
|
/// Quantization-scale mode. Every enumerator carries an explicit value so the
/// numbering stays aligned with its numeric text alias ("0", "1", "2").
enum class quant_scale_enum
{
    no_scale   = 0, // text token "n"
    pertensor  = 1, // text token "pt"
    blockscale = 2, // text token "bs"
};
|
|
|
|
struct quant_scale_info
|
|
{
|
|
quant_scale_enum type;
|
|
|
|
void serialize(std::ostream& os) const
|
|
{
|
|
if(type == quant_scale_enum::no_scale)
|
|
os << "n";
|
|
else if(type == quant_scale_enum::pertensor)
|
|
os << "pt";
|
|
else if(type == quant_scale_enum::blockscale)
|
|
os << "bs";
|
|
}
|
|
|
|
static quant_scale_info decode(std::string str)
|
|
{
|
|
quant_scale_info info{quant_scale_enum::no_scale};
|
|
if(str == "n" || str == "0")
|
|
{
|
|
info.type = quant_scale_enum::no_scale;
|
|
}
|
|
else if(str == "pt" || str == "1")
|
|
{
|
|
info.type = quant_scale_enum::pertensor;
|
|
}
|
|
else if(str == "bs" || str == "2")
|
|
{
|
|
info.type = quant_scale_enum::blockscale;
|
|
}
|
|
else
|
|
{
|
|
throw std::invalid_argument("invalid quant scale value: " + str);
|
|
}
|
|
return info;
|
|
}
|
|
|
|
friend std::ostream& operator<<(std::ostream& os, const quant_scale_info& qsi)
|
|
{
|
|
qsi.serialize(os);
|
|
return os;
|
|
}
|
|
};
|