mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 00:40:09 +00:00
Merge remote-tracking branch 'origin/develop' into samremes/ck_tile_mx_gemm
This commit is contained in:
@@ -77,11 +77,13 @@ def get_mask_cpp_check_expr(mask: str) -> str:
|
||||
QSCALE_MAP = {
|
||||
"no": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE",
|
||||
"pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR",
|
||||
"blockscale": "ck_tile::BlockAttentionQuantScaleEnum::BLOCKSCALE",
|
||||
}
|
||||
|
||||
QSCALE_CHECK_MAP = {
|
||||
"no": "quant_scale_enum::no_scale",
|
||||
"pertensor": "quant_scale_enum::pertensor",
|
||||
"blockscale": "quant_scale_enum::blockscale",
|
||||
}
|
||||
|
||||
BIAS_MAP = {
|
||||
|
||||
@@ -36,7 +36,7 @@ DTYPE_BITS = {
|
||||
|
||||
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}
|
||||
|
||||
SUPPORTED_PAGE_SIZE = [1, 128, 256, 1024]
|
||||
SUPPORTED_PAGE_SIZE = [1, 16, 1024]
|
||||
SUPPORTED_KV_MEMORY_LAYOUT = ["vectorized", "linear"]
|
||||
SUPPORTED_KV_LOOKUP_TABLE = ["vllm", "sglang"]
|
||||
KV_MEMORY_LAYOUT_ENUM_MAP = {
|
||||
@@ -630,6 +630,7 @@ class KernelComponentFactory:
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
return {
|
||||
128 : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
256 : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
} # fmt: skip
|
||||
elif dtype in ["fp8bf16"]:
|
||||
return {
|
||||
|
||||
@@ -315,7 +315,7 @@ class FmhaFwdApiTrait:
|
||||
assert False
|
||||
|
||||
def seqtune(self, max_bm0: int) -> str:
|
||||
if self.bm0 == max_bm0:
|
||||
if self.bm0 == max_bm0 or self.bm0 == 64:
|
||||
return "true/*fall back to largest tile*/"
|
||||
else:
|
||||
return f"a.seqlen_q <= {self.bm0}"
|
||||
@@ -847,6 +847,11 @@ class CompatibilityRuleFactoryGfx9(CompatibilityRuleFactory):
|
||||
(problem_ctx.hdim, problem_ctx.hdim_v) != (128, 128)
|
||||
and kernel_ctx.tile.F_bm0 != 128
|
||||
)
|
||||
or (
|
||||
(problem_ctx.hdim, problem_ctx.hdim_v) == (128, 128)
|
||||
and kernel_ctx.pipeline.tag != "qr_async"
|
||||
and kernel_ctx.tile.F_bk0 == 64
|
||||
)
|
||||
):
|
||||
# non qr_async_trload only support km0=128 tile size when hdim is not 128
|
||||
# non qr_async only support kn0=128 tile size when hdim is 128
|
||||
@@ -942,6 +947,7 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9):
|
||||
( 96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
(128, 128) : [FmhaFwdTileSize( 16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
FmhaFwdTileSize( 32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
FmhaFwdTileSize( 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 16, -1, CppConstraint('get_num_blocks(64) <= num_cus')),
|
||||
FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
# (160, 160) : [FmhaFwdTileSize(128, 128 , 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
|
||||
@@ -1018,7 +1024,7 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9):
|
||||
# no need lse/dropout kernels
|
||||
for logits, qscale, mask, bias, sink in itertools.product(
|
||||
["t", "f"],
|
||||
["no", "pertensor"],
|
||||
["no", "pertensor", "blockscale"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
["no"],
|
||||
["f", "t"],
|
||||
@@ -1146,7 +1152,10 @@ class KernelComponentFactoryGfx12(CompatibilityRuleFactory):
|
||||
elif dtype in cls._DT_FP8_FP8BF16 or dtype in cls._DT_FP8FP32:
|
||||
# no need lse/dropout kernels
|
||||
for logits, qscale, mask, bias in itertools.product(
|
||||
["f"], ["no", "pertensor"], get_mask_map(mask_impl).keys(), ["no"]
|
||||
["f"],
|
||||
["no", "pertensor", "blockscale"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
["no"],
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip
|
||||
|
||||
@@ -189,7 +189,7 @@ struct fmha_bwd_args
|
||||
ck_tile::index_t nhead_stride_randval;
|
||||
ck_tile::index_t nhead_stride_do;
|
||||
ck_tile::index_t nhead_stride_lsed;
|
||||
ck_tile::index_t nhead_stride_dq_acc;
|
||||
ck_tile::long_index_t nhead_stride_dq_acc;
|
||||
ck_tile::index_t nhead_stride_dq;
|
||||
ck_tile::index_t nhead_stride_dk;
|
||||
ck_tile::index_t nhead_stride_dv;
|
||||
@@ -202,7 +202,7 @@ struct fmha_bwd_args
|
||||
ck_tile::index_t batch_stride_randval;
|
||||
ck_tile::index_t batch_stride_do;
|
||||
ck_tile::index_t batch_stride_lsed;
|
||||
ck_tile::index_t batch_stride_dq_acc;
|
||||
ck_tile::long_index_t batch_stride_dq_acc;
|
||||
ck_tile::index_t batch_stride_dq;
|
||||
ck_tile::index_t batch_stride_dk;
|
||||
ck_tile::index_t batch_stride_dv;
|
||||
|
||||
@@ -287,9 +287,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
|
||||
ck_tile::HostTensor<AccDataType> dq_acc_host(
|
||||
i_perm
|
||||
? std::array<ck_tile::index_t, 5>{nsplits, shape_batch, nhead, shape_seqlen_q, hdim_q}
|
||||
: std::array<ck_tile::index_t, 5>{nsplits, shape_batch, shape_seqlen_q, nhead, hdim_q});
|
||||
std::array<ck_tile::index_t, 5>{shape_batch, nhead, nsplits, shape_seqlen_q, hdim_q});
|
||||
|
||||
if(init_method == "ui" || init_method == "0")
|
||||
{
|
||||
@@ -433,6 +431,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
const ck_tile::index_t stride_dk = (i_perm ? hdim_q : nhead * hdim_q);
|
||||
const ck_tile::index_t stride_dv = (i_perm ? hdim_v : nhead * hdim_v);
|
||||
const ck_tile::index_t stride_dbias = (i_perm ? max_seqlen_k : nhead * max_seqlen_k);
|
||||
const auto split_stride_dq_acc = (shape_seqlen_q * hdim_q);
|
||||
// setup nhead_stride_* arguments
|
||||
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
|
||||
const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q);
|
||||
@@ -444,6 +443,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
const ck_tile::index_t nhead_stride_lsed = shape_seqlen_q;
|
||||
const ck_tile::index_t nhead_stride_dbias =
|
||||
(i_perm ? shape_seqlen_q * max_seqlen_k : max_seqlen_k);
|
||||
const auto nhead_stride_dq_acc =
|
||||
static_cast<ck_tile::long_index_t>(split_stride_dq_acc) * nsplits;
|
||||
// setup batch_stride_* arguments
|
||||
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
|
||||
const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q);
|
||||
@@ -456,8 +457,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
const ck_tile::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q);
|
||||
const ck_tile::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v);
|
||||
const ck_tile::index_t batch_stride_dbias = (nhead * shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t split_stride_dq_acc =
|
||||
(shape_batch * nhead * shape_seqlen_q * hdim_q);
|
||||
const auto batch_stride_dq_acc = nhead * nhead_stride_dq_acc;
|
||||
|
||||
const auto drop_seed_offset = [&]() -> decltype(fmha_bwd_args::drop_seed_offset) {
|
||||
if(drop_prefs)
|
||||
@@ -513,7 +513,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
stride_o,
|
||||
stride_randval,
|
||||
stride_do,
|
||||
stride_q, // stride_dq_acc
|
||||
hdim_q, // stride_dq_acc
|
||||
stride_q, // stride_dq
|
||||
stride_dk,
|
||||
stride_dv,
|
||||
@@ -526,7 +526,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
nhead_stride_randval,
|
||||
nhead_stride_do,
|
||||
nhead_stride_lsed,
|
||||
nhead_stride_q, // nhead_stride_dq_acc
|
||||
nhead_stride_dq_acc,
|
||||
nhead_stride_q, // nhead_stride_dq
|
||||
nhead_stride_k, // nhead_stride_dk
|
||||
nhead_stride_v, // nhead_stride_dv
|
||||
@@ -539,7 +539,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
batch_stride_randval,
|
||||
batch_stride_do,
|
||||
batch_stride_lsed,
|
||||
batch_stride_q, // batch_stride_dq_acc
|
||||
batch_stride_dq_acc,
|
||||
batch_stride_q, // batch_stride_dq
|
||||
batch_stride_dk,
|
||||
batch_stride_dv,
|
||||
|
||||
@@ -230,6 +230,8 @@ struct fmha_fwd_args
|
||||
// array [batch + 1]. (Used with padding)
|
||||
const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
|
||||
// array [batch + 1]. (Used with padding)
|
||||
const void* block_scale_seqstart_q_ptr;
|
||||
const void* block_scale_seqstart_k_ptr;
|
||||
const void* sink_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
@@ -257,6 +259,9 @@ struct fmha_fwd_args
|
||||
ck_tile::index_t nhead_stride_randval;
|
||||
ck_tile::index_t nhead_stride_lse;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
ck_tile::index_t nhead_stride_q_descale;
|
||||
ck_tile::index_t nhead_stride_k_descale;
|
||||
ck_tile::index_t nhead_stride_v_descale;
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
@@ -264,6 +269,9 @@ struct fmha_fwd_args
|
||||
ck_tile::index_t batch_stride_randval;
|
||||
ck_tile::index_t batch_stride_lse;
|
||||
ck_tile::index_t batch_stride_o;
|
||||
ck_tile::index_t batch_stride_q_descale;
|
||||
ck_tile::index_t batch_stride_k_descale;
|
||||
ck_tile::index_t batch_stride_v_descale;
|
||||
|
||||
ck_tile::index_t window_size_left;
|
||||
ck_tile::index_t window_size_right;
|
||||
@@ -276,6 +284,9 @@ struct fmha_fwd_args
|
||||
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset;
|
||||
|
||||
ck_tile::index_t block_scale_size_q;
|
||||
ck_tile::index_t block_scale_size_kv;
|
||||
};
|
||||
|
||||
struct fmha_fwd_pagedkv_args
|
||||
@@ -615,6 +626,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_q_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
args.block_scale_seqstart_q_ptr,
|
||||
args.block_scale_seqstart_k_ptr,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
@@ -634,6 +647,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.nhead_stride_q_descale,
|
||||
args.nhead_stride_k_descale,
|
||||
args.nhead_stride_v_descale,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.sink_size,
|
||||
@@ -642,6 +658,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset,
|
||||
args.block_scale_size_q,
|
||||
args.block_scale_size_kv,
|
||||
args.cu_seqlen_q_ptr,
|
||||
args.cu_seqlen_k_ptr,
|
||||
args.sink_ptr);
|
||||
@@ -679,6 +697,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.nhead_stride_q_descale,
|
||||
args.nhead_stride_k_descale,
|
||||
args.nhead_stride_v_descale,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
@@ -686,6 +707,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.batch_stride_randval,
|
||||
args.batch_stride_lse,
|
||||
args.batch_stride_o,
|
||||
args.batch_stride_q_descale,
|
||||
args.batch_stride_k_descale,
|
||||
args.batch_stride_v_descale,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.sink_size,
|
||||
@@ -693,6 +717,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset,
|
||||
args.block_scale_size_q,
|
||||
args.block_scale_size_kv,
|
||||
args.cu_seqlen_q_ptr,
|
||||
args.cu_seqlen_k_ptr,
|
||||
args.sink_ptr);
|
||||
|
||||
@@ -210,6 +210,11 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
const ck_tile::stream_config& stream_config,
|
||||
std::optional<std::string> json = std::nullopt)
|
||||
{
|
||||
// Note: block_scale_size_q_ and block_scale_size_kv_ should be greater than or equal to the
|
||||
// compute block size
|
||||
constexpr ck_tile::index_t block_scale_size_q_ = 128;
|
||||
constexpr ck_tile::index_t block_scale_size_kv_ = 128;
|
||||
|
||||
const std::string data_type = []() {
|
||||
if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp32>)
|
||||
return "fp32";
|
||||
@@ -471,7 +476,11 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
std::size_t flop = 0, num_byte = 0;
|
||||
auto max_seqlen_q =
|
||||
std::numeric_limits<int32_t>::min(); // we will use max seqlen to decide grid size
|
||||
auto max_seqlen_k = std::numeric_limits<int32_t>::min();
|
||||
size_t i_block_scale_q = 0;
|
||||
size_t i_block_scale_k = 0;
|
||||
std::vector<int32_t> block_scale_seqstart_q_host = {0};
|
||||
std::vector<int32_t> block_scale_seqstart_k_host = {0};
|
||||
auto max_seqlen_k = std::numeric_limits<int32_t>::min();
|
||||
{
|
||||
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
|
||||
{
|
||||
@@ -487,6 +496,10 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
{
|
||||
max_seqlen_k = real_seqlen_k;
|
||||
}
|
||||
i_block_scale_q += ck_tile::integer_divide_ceil(real_seqlen_q, block_scale_size_q_);
|
||||
i_block_scale_k += ck_tile::integer_divide_ceil(real_seqlen_k, block_scale_size_kv_);
|
||||
block_scale_seqstart_q_host.push_back(i_block_scale_q);
|
||||
block_scale_seqstart_k_host.push_back(i_block_scale_k);
|
||||
|
||||
flop += nhead * (static_cast<std::size_t>(2) * mask.get_unmaskarea() * hdim_q +
|
||||
static_cast<std::size_t>(2) * mask.get_unmaskarea() * hdim_v);
|
||||
@@ -548,6 +561,15 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
? seqstart_k_with_padding_host.back()
|
||||
: seqstart_k_host.back()));
|
||||
|
||||
const ck_tile::index_t num_block_scale_q =
|
||||
(mode == mode_enum::batch)
|
||||
? ck_tile::integer_divide_ceil(shape_seqlen_q, block_scale_size_q_)
|
||||
: i_block_scale_q;
|
||||
const ck_tile::index_t num_block_scale_kv =
|
||||
(mode == mode_enum::batch)
|
||||
? ck_tile::integer_divide_ceil(shape_seqlen_k, block_scale_size_kv_)
|
||||
: i_block_scale_k;
|
||||
|
||||
ck_tile::HostTensor<QDataType> q_host(
|
||||
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
|
||||
ck_tile::HostTensor<SMPLComputeDataType> sink_host({nhead});
|
||||
@@ -599,9 +621,18 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
|
||||
|
||||
// TODO - change the tensor length for different quant scale
|
||||
ck_tile::HostTensor<float> q_descale_host(get_lengths(i_perm, 1, 1, 1, 1));
|
||||
ck_tile::HostTensor<float> k_descale_host(get_lengths(i_perm, 1, 1, 1, 1));
|
||||
ck_tile::HostTensor<float> v_descale_host(get_lengths(i_perm, 1, 1, 1, 1));
|
||||
ck_tile::HostTensor<float> q_descale_host(
|
||||
qscale.type == quant_scale_enum::blockscale
|
||||
? std::array<ck_tile::index_t, 3>{shape_batch, nhead, num_block_scale_q}
|
||||
: std::array<ck_tile::index_t, 3>{1, 1, 1});
|
||||
ck_tile::HostTensor<float> k_descale_host(
|
||||
qscale.type == quant_scale_enum::blockscale
|
||||
? std::array<ck_tile::index_t, 3>{shape_batch, nhead_k, num_block_scale_kv}
|
||||
: std::array<ck_tile::index_t, 3>{1, 1, 1});
|
||||
ck_tile::HostTensor<float> v_descale_host(
|
||||
qscale.type == quant_scale_enum::blockscale
|
||||
? std::array<ck_tile::index_t, 3>{shape_batch, nhead_k, num_block_scale_kv}
|
||||
: std::array<ck_tile::index_t, 3>{1, 1, 1});
|
||||
|
||||
// batch mode of lse data layout is [batch, nhead, seqlen_q]
|
||||
// group mode of lse data layout is [nhead, total_seqlen_q]
|
||||
@@ -717,6 +748,24 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
k_descale_host(0) = qkv_max / k_dtype_max;
|
||||
v_descale_host(0) = qkv_max / v_dtype_max;
|
||||
}
|
||||
else if(qscale.type == quant_scale_enum::blockscale)
|
||||
{
|
||||
float q_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::max());
|
||||
float k_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<KDataType>::max());
|
||||
float v_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<VDataType>::max());
|
||||
|
||||
float qkv_max = 3.f;
|
||||
float max_descale_q = qkv_max / q_dtype_max;
|
||||
float max_descale_k = qkv_max / k_dtype_max;
|
||||
float max_descale_v = qkv_max / v_dtype_max;
|
||||
|
||||
ck_tile::FillUniformDistribution<float>{max_descale_q * 0.8f, max_descale_q, next_seed()}(
|
||||
q_descale_host);
|
||||
ck_tile::FillUniformDistribution<float>{max_descale_k * 0.8f, max_descale_k, next_seed()}(
|
||||
k_descale_host);
|
||||
ck_tile::FillUniformDistribution<float>{max_descale_v * 0.8f, max_descale_v, next_seed()}(
|
||||
v_descale_host);
|
||||
}
|
||||
|
||||
iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine);
|
||||
iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0, random_engine);
|
||||
@@ -737,6 +786,10 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
ck_tile::DeviceMem q_descale_buf(q_descale_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem k_descale_buf(k_descale_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem v_descale_buf(v_descale_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem block_scale_seqstart_q_buf(block_scale_seqstart_q_host.size() *
|
||||
sizeof(int32_t));
|
||||
ck_tile::DeviceMem block_scale_seqstart_k_buf(block_scale_seqstart_k_host.size() *
|
||||
sizeof(int32_t));
|
||||
ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes());
|
||||
@@ -782,6 +835,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
q_descale_buf.ToDevice(q_descale_host.data());
|
||||
k_descale_buf.ToDevice(k_descale_host.data());
|
||||
v_descale_buf.ToDevice(v_descale_host.data());
|
||||
block_scale_seqstart_q_buf.ToDevice(block_scale_seqstart_q_host.data());
|
||||
block_scale_seqstart_k_buf.ToDevice(block_scale_seqstart_k_host.data());
|
||||
seqstart_q.ToDevice(seqstart_q_host.data());
|
||||
// Keep logical starts in seqstart_k; pass padded K via separate pointer
|
||||
seqstart_k.ToDevice(seqstart_k_host.data());
|
||||
@@ -975,11 +1030,14 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
}();
|
||||
const ck_tile::index_t nhead_stride_bias =
|
||||
(i_perm ? 0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k);
|
||||
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
|
||||
const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q);
|
||||
const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
|
||||
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
|
||||
const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q);
|
||||
const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
|
||||
const ck_tile::index_t nhead_stride_q_descale = num_block_scale_q;
|
||||
const ck_tile::index_t nhead_stride_k_descale = num_block_scale_kv;
|
||||
const ck_tile::index_t nhead_stride_v_descale = num_block_scale_kv;
|
||||
// setup batch_stride_* arguments
|
||||
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
|
||||
const ck_tile::index_t batch_stride_k =
|
||||
@@ -997,6 +1055,9 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch);
|
||||
const ck_tile::index_t batch_stride_q_descale = num_block_scale_q * nhead;
|
||||
const ck_tile::index_t batch_stride_k_descale = num_block_scale_kv * nhead_k;
|
||||
const ck_tile::index_t batch_stride_v_descale = num_block_scale_kv * nhead_k;
|
||||
// setup split_stride_* arguments (only used in split-kv kernel)
|
||||
const ck_tile::index_t split_stride_lse_acc = (shape_seqlen_q);
|
||||
const ck_tile::index_t split_stride_o_acc = (shape_seqlen_q * hdim_v);
|
||||
@@ -1084,9 +1145,39 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
|
||||
if constexpr(std::is_same_v<fmha_fwd_args, std::decay_t<decltype(args)>>)
|
||||
{
|
||||
args.q_descale_ptr = q_descale_buf.GetDeviceBuffer();
|
||||
args.k_descale_ptr = k_descale_buf.GetDeviceBuffer();
|
||||
args.v_descale_ptr = v_descale_buf.GetDeviceBuffer();
|
||||
if(qscale.type == quant_scale_enum::blockscale)
|
||||
{
|
||||
args.q_descale_ptr =
|
||||
reinterpret_cast<const float*>(q_descale_buf.GetDeviceBuffer());
|
||||
args.k_descale_ptr =
|
||||
reinterpret_cast<const float*>(k_descale_buf.GetDeviceBuffer());
|
||||
args.v_descale_ptr =
|
||||
reinterpret_cast<const float*>(v_descale_buf.GetDeviceBuffer());
|
||||
|
||||
args.block_scale_seqstart_q_ptr =
|
||||
(mode == mode_enum::group ? block_scale_seqstart_q_buf.GetDeviceBuffer()
|
||||
: nullptr);
|
||||
args.block_scale_seqstart_k_ptr =
|
||||
(mode == mode_enum::group ? block_scale_seqstart_k_buf.GetDeviceBuffer()
|
||||
: nullptr);
|
||||
|
||||
args.nhead_stride_q_descale = nhead_stride_q_descale;
|
||||
args.nhead_stride_k_descale = nhead_stride_k_descale;
|
||||
args.nhead_stride_v_descale = nhead_stride_v_descale;
|
||||
|
||||
args.batch_stride_q_descale = batch_stride_q_descale;
|
||||
args.batch_stride_k_descale = batch_stride_k_descale;
|
||||
args.batch_stride_v_descale = batch_stride_v_descale;
|
||||
|
||||
args.block_scale_size_q = block_scale_size_q_;
|
||||
args.block_scale_size_kv = block_scale_size_kv_;
|
||||
}
|
||||
else
|
||||
{
|
||||
args.q_descale_ptr = q_descale_buf.GetDeviceBuffer();
|
||||
args.k_descale_ptr = k_descale_buf.GetDeviceBuffer();
|
||||
args.v_descale_ptr = v_descale_buf.GetDeviceBuffer();
|
||||
}
|
||||
|
||||
args.rand_val_ptr = randval_buf.GetDeviceBuffer();
|
||||
|
||||
@@ -1589,14 +1680,42 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
#endif
|
||||
|
||||
// reference
|
||||
ck_tile::
|
||||
reference_batched_gemm<QDataType, KDataType, SaccDataType, SMPLComputeDataType>(
|
||||
if(qscale.type == quant_scale_enum::blockscale)
|
||||
{
|
||||
const ck_tile::index_t q_offset =
|
||||
(mode == mode_enum::batch) ? 0 : block_scale_seqstart_q_host[wb];
|
||||
const ck_tile::index_t k_offset =
|
||||
(mode == mode_enum::batch) ? 0 : block_scale_seqstart_k_host[wb];
|
||||
ck_tile::reference_batched_quant_gemm<QDataType,
|
||||
KDataType,
|
||||
SaccDataType,
|
||||
SMPLComputeDataType>(
|
||||
q_host_ref,
|
||||
k_host_ref,
|
||||
s_host_ref,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::scales(scale_s_host));
|
||||
ck_tile::idx_identity{},
|
||||
ck_tile::idx_identity{},
|
||||
[&](auto idx, auto value) {
|
||||
return value * scale_s *
|
||||
q_descale_host(b_idx,
|
||||
std::get<0>(idx),
|
||||
q_offset + std::get<1>(idx) / block_scale_size_q_) *
|
||||
k_descale_host(b_idx,
|
||||
std::get<0>(idx) / nr,
|
||||
k_offset + std::get<2>(idx) / block_scale_size_kv_);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::
|
||||
reference_batched_gemm<QDataType, KDataType, SaccDataType, SMPLComputeDataType>(
|
||||
q_host_ref,
|
||||
k_host_ref,
|
||||
s_host_ref,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::scales(scale_s_host));
|
||||
}
|
||||
|
||||
if(0.f < logits_soft_cap)
|
||||
{
|
||||
@@ -1794,13 +1913,35 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
}
|
||||
}
|
||||
|
||||
ck_tile::reference_batched_gemm<PDataType, VDataType, OaccDataType, ODataType>(
|
||||
p_host_ref,
|
||||
v_host_ref,
|
||||
o_host_ref,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
oacc_element_func);
|
||||
if(qscale.type == quant_scale_enum::blockscale)
|
||||
{
|
||||
const ck_tile::index_t v_offset =
|
||||
(mode == mode_enum::batch) ? 0 : block_scale_seqstart_k_host[wb];
|
||||
ck_tile::
|
||||
reference_batched_quant_gemm<PDataType, VDataType, OaccDataType, ODataType>(
|
||||
p_host_ref,
|
||||
v_host_ref,
|
||||
o_host_ref,
|
||||
ck_tile::idx_identity{},
|
||||
[&](auto idx, auto value) {
|
||||
return ck_tile::type_convert<float>(value) *
|
||||
v_descale_host(b_idx,
|
||||
std::get<0>(idx) / nr,
|
||||
v_offset +
|
||||
std::get<2>(idx) / block_scale_size_kv_);
|
||||
},
|
||||
ck_tile::idx_identity{});
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::reference_batched_gemm<PDataType, VDataType, OaccDataType, ODataType>(
|
||||
p_host_ref,
|
||||
v_host_ref,
|
||||
o_host_ref,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
oacc_element_func);
|
||||
}
|
||||
|
||||
ck_tile::HostTensor<ODataType> o_host_result({nhead, real_seqlen_q, hdim_v});
|
||||
// clang-format off
|
||||
@@ -1808,7 +1949,6 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[0], idx[1] + query_offset, idx[2]); });
|
||||
else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); });
|
||||
// clang-format on
|
||||
|
||||
auto [rtol, atol] = get_elimit<DataTypeConfig>(init_method);
|
||||
bool cur_pass = ck_tile::check_err(o_host_result,
|
||||
o_host_ref,
|
||||
@@ -1866,31 +2006,33 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
|
||||
if(json)
|
||||
{
|
||||
dump_fmha_fwd_json_results(*json,
|
||||
data_type,
|
||||
mode == mode_enum::batch ? "batch" : "group",
|
||||
io_layout(i_perm, o_perm),
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_qs[0],
|
||||
seqlen_ks[0],
|
||||
seqlen_kpads[0],
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
scale_s,
|
||||
p_drop,
|
||||
lse,
|
||||
qscale.type == quant_scale_enum::no_scale ? "no_scale"
|
||||
: "pertensor",
|
||||
bias.type == bias_enum::elementwise_bias
|
||||
? "elementwise_bias"
|
||||
: (bias.type == bias_enum::alibi ? "alibi" : "no_bias"),
|
||||
is_v_rowmajor ? "r" : "c",
|
||||
pass,
|
||||
ave_time,
|
||||
tflops,
|
||||
gb_per_sec);
|
||||
dump_fmha_fwd_json_results(
|
||||
*json,
|
||||
data_type,
|
||||
mode == mode_enum::batch ? "batch" : "group",
|
||||
io_layout(i_perm, o_perm),
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_qs[0],
|
||||
seqlen_ks[0],
|
||||
seqlen_kpads[0],
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
scale_s,
|
||||
p_drop,
|
||||
lse,
|
||||
qscale.type == quant_scale_enum::no_scale
|
||||
? "no_scale"
|
||||
: (qscale.type == quant_scale_enum::pertensor ? "pertensor" : "blockscale"),
|
||||
bias.type == bias_enum::elementwise_bias
|
||||
? "elementwise_bias"
|
||||
: (bias.type == bias_enum::alibi ? "alibi" : "no_bias"),
|
||||
is_v_rowmajor ? "r" : "c",
|
||||
pass,
|
||||
ave_time,
|
||||
tflops,
|
||||
gb_per_sec);
|
||||
}
|
||||
|
||||
return pass ? fwd_result::success : fwd_result::failure;
|
||||
|
||||
@@ -13,6 +13,7 @@ enum class quant_scale_enum
|
||||
{
|
||||
no_scale = 0,
|
||||
pertensor = 1,
|
||||
blockscale,
|
||||
};
|
||||
|
||||
struct quant_scale_info
|
||||
@@ -25,6 +26,8 @@ struct quant_scale_info
|
||||
os << "n";
|
||||
else if(type == quant_scale_enum::pertensor)
|
||||
os << "pt";
|
||||
else if(type == quant_scale_enum::blockscale)
|
||||
os << "bs";
|
||||
}
|
||||
|
||||
static quant_scale_info decode(std::string str)
|
||||
@@ -38,6 +41,10 @@ struct quant_scale_info
|
||||
{
|
||||
info.type = quant_scale_enum::pertensor;
|
||||
}
|
||||
else if(str == "bs" || str == "2")
|
||||
{
|
||||
info.type = quant_scale_enum::blockscale;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::invalid_argument("invalid quant scale value: " + str);
|
||||
|
||||
@@ -95,10 +95,11 @@ run_fp8bf16_tests() {
|
||||
for perm in 0 1 ; do
|
||||
for b in 1 2 ; do
|
||||
for hdim in 64 128 256 ; do
|
||||
for scale in 1 2; do
|
||||
|
||||
$EXE -prec=fp8bf16 -init=3 -b=$b -h=1 -d=$hdim -s=128 -iperm=$perm -operm=$perm -vlayout=r -qscale=1 -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=fp8bf16 -init=3 -b=$b -h=1 -d=$hdim -s=128 -iperm=$perm -operm=$perm -vlayout=r -qscale=$scale -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
done ; done ; done
|
||||
done ; done ; done ; done
|
||||
}
|
||||
|
||||
run_fp8fp32_tests() {
|
||||
|
||||
@@ -456,7 +456,8 @@ inline auto create_args()
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "gemm.json", "json file name to dump results")
|
||||
.insert("flush_cache", "true", "flush cache before running the kernel, defaults to true")
|
||||
.insert("rotating_count", "1000", "rotating count, defaults to 1000");
|
||||
.insert("rotating_count", "1000", "rotating count, defaults to 1000")
|
||||
.insert("test_async", "0", "0: normal gemm, 1: test async input scheduler");
|
||||
return arg_parser;
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,169 @@
|
||||
#include "run_gemm_example_common.hpp"
|
||||
#include "universal_gemm_invoker.hpp"
|
||||
|
||||
// Universal GEMM-specific wrapper that handles test_async flag
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType = ADataType,
|
||||
typename CDataType = ADataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
int run_gemm_example_with_layouts_universal(ck_tile::ArgParser& arg_parser,
|
||||
const ALayout a_layout = ALayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
const CLayout c_layout = CLayout{})
|
||||
{
|
||||
using Invoker = UniversalInvoker;
|
||||
using AccDataType = typename GemmTypeConfig<ADataType, BDataType, CDataType>::AccDataType;
|
||||
|
||||
// Check for async input scheduler test mode
|
||||
bool test_async = arg_parser.get_int("test_async");
|
||||
if(test_async)
|
||||
{
|
||||
// Extract parameters for async test (same as shared implementation)
|
||||
const ck_tile::index_t M = arg_parser.get_int("m");
|
||||
const ck_tile::index_t N = arg_parser.get_int("n");
|
||||
const ck_tile::index_t K = arg_parser.get_int("k");
|
||||
const ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
constexpr bool is_a_row_major = std::is_same_v<ALayout, Row>;
|
||||
constexpr bool is_b_row_major = std::is_same_v<BLayout, Row>;
|
||||
constexpr bool is_c_row_major = std::is_same_v<CLayout, Row>;
|
||||
|
||||
const ck_tile::index_t stride_A = is_a_row_major ? K : M;
|
||||
const ck_tile::index_t stride_B = is_b_row_major ? N : K;
|
||||
const ck_tile::index_t stride_C = is_c_row_major ? N : M;
|
||||
|
||||
// Allocate and initialize tensors
|
||||
ck_tile::HostTensor<ADataType> a_m_k(ck_tile::host_tensor_descriptor(
|
||||
M, K, stride_A, ck_tile::bool_constant<is_a_row_major>{}));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(ck_tile::host_tensor_descriptor(
|
||||
K, N, stride_B, ck_tile::bool_constant<is_b_row_major>{}));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(ck_tile::host_tensor_descriptor(
|
||||
M, N, stride_C, ck_tile::bool_constant<is_c_row_major>{}));
|
||||
|
||||
ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5, 5}(a_m_k);
|
||||
ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5, 5}(b_k_n);
|
||||
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
ck_tile::GemmHostArgs args = {a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
c_m_n_dev_buf.GetDeviceBuffer(),
|
||||
kbatch,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C};
|
||||
|
||||
Invoker::template test_async_input_scheduler<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough>(
|
||||
args, ck_tile::stream_config{nullptr, false, 1});
|
||||
|
||||
// Copy result from device for verification
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
|
||||
// Compute CPU reference
|
||||
ck_tile::HostTensor<CDataType> c_m_n_ref(ck_tile::host_tensor_descriptor(
|
||||
M, N, stride_C, ck_tile::bool_constant<is_c_row_major>{}));
|
||||
c_m_n_ref.SetZero();
|
||||
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_m_k, b_k_n, c_m_n_ref);
|
||||
|
||||
// Verify results
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
bool pass = do_verify(c_m_n_dev_result, c_m_n_ref, rtol_atol, "CPU");
|
||||
|
||||
std::cout << "Async input scheduler test: " << (pass ? "PASS" : "FAIL") << std::endl;
|
||||
return pass;
|
||||
}
|
||||
|
||||
// Normal path - delegate to shared implementation
|
||||
return run_gemm_example_with_layouts<GemmConfig, Invoker, ADataType, BDataType, CDataType>(
|
||||
arg_parser, a_layout, b_layout, c_layout);
|
||||
}
|
||||
|
||||
// Universal GEMM-specific prec_type dispatcher that uses the wrapper
|
||||
template <typename GemmConfig,
|
||||
typename APrecType,
|
||||
typename BPrecType = APrecType,
|
||||
typename CPrecType = APrecType>
|
||||
int run_gemm_example_prec_type_universal(std::string a_layout,
|
||||
std::string b_layout,
|
||||
ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
bool preshuffle = GemmConfig::Preshuffle;
|
||||
|
||||
if(preshuffle && std::is_same_v<BPrecType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
throw std::runtime_error("Preshuffle is not supported for this int4 datatype!");
|
||||
}
|
||||
|
||||
if(preshuffle && a_layout != "R" && b_layout != "C")
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Preshuffle is supported only for A(Row major), B(column major) input matrices!");
|
||||
}
|
||||
|
||||
using LayoutVariant = std::variant<Row, Col>;
|
||||
|
||||
auto string_to_layout = [](const std::string& layout) -> LayoutVariant {
|
||||
if(layout == "R")
|
||||
return Row{};
|
||||
if(layout == "C")
|
||||
return Col{};
|
||||
throw std::runtime_error("Unsupported layout: " + layout);
|
||||
};
|
||||
|
||||
auto a_layout_variant = string_to_layout(a_layout);
|
||||
auto b_layout_variant = string_to_layout(b_layout);
|
||||
|
||||
return std::visit(
|
||||
[&](auto a_layout_type, auto b_layout_type) -> int {
|
||||
if constexpr(std::is_same_v<BPrecType, ck_tile::pk_int4_t> &&
|
||||
std::is_same_v<decltype(b_layout_type), Row>)
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout for the input matrices when "
|
||||
"BPrecType is ck_tile::pk_int4_t!");
|
||||
}
|
||||
else
|
||||
{
|
||||
return run_gemm_example_with_layouts_universal<GemmConfig,
|
||||
APrecType,
|
||||
BPrecType,
|
||||
CPrecType>(
|
||||
arg_parser, a_layout_type, b_layout_type, Row{});
|
||||
}
|
||||
},
|
||||
a_layout_variant,
|
||||
b_layout_variant);
|
||||
}
|
||||
|
||||
template <template <typename PrecType> typename GemmConfig>
|
||||
int run_gemm_example(ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
@@ -19,52 +182,50 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
using Invoker = UniversalInvoker;
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, Invoker, ck_tile::half_t>(
|
||||
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
|
||||
a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t>, Invoker, ck_tile::bf16_t>(
|
||||
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::bf16_t>, ck_tile::bf16_t>(
|
||||
a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
Invoker,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, arg_parser);
|
||||
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t>(
|
||||
a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
Invoker,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, arg_parser);
|
||||
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t>(
|
||||
a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "int8")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::int8_t>,
|
||||
Invoker,
|
||||
ck_tile::int8_t,
|
||||
ck_tile::int8_t,
|
||||
ck_tile::int32_t>(a_layout, b_layout, arg_parser);
|
||||
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::int8_t>,
|
||||
ck_tile::int8_t,
|
||||
ck_tile::int8_t,
|
||||
ck_tile::int32_t>(
|
||||
a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "fp16i4")
|
||||
{
|
||||
// TODO: Add support for bhalf_t ADataType
|
||||
if constexpr(GemmConfig<ck_tile::half_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>,
|
||||
Invoker,
|
||||
ck_tile::half_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, arg_parser);
|
||||
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::half_t>,
|
||||
ck_tile::half_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t>(
|
||||
a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -75,11 +236,11 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
if constexpr(GemmConfig<ck_tile::fp8_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
Invoker,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, arg_parser);
|
||||
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t>(
|
||||
a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -90,11 +251,11 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
if constexpr(GemmConfig<ck_tile::bf8_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
Invoker,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, arg_parser);
|
||||
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t>(
|
||||
a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -2,7 +2,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
#include <functional>
|
||||
#include <chrono>
|
||||
#include <thread>
|
||||
#include "gemm_utils.hpp"
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
#include "ck_tile/host/device_memory.hpp"
|
||||
|
||||
struct UniversalInvoker
|
||||
{
|
||||
@@ -150,4 +154,170 @@ struct UniversalInvoker
|
||||
preprocess,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename CDEElementWise>
|
||||
static void test_async_input_scheduler(const ck_tile::GemmHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
|
||||
GemmConfig::PermuteA,
|
||||
GemmConfig::PermuteB>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
GemmConfig::TileParitionerGroupNum,
|
||||
GemmConfig::TileParitionerM01>;
|
||||
|
||||
using GemmUniversalTraits =
|
||||
ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::UseStructuredSparsity,
|
||||
true, // Persistent = true for async test
|
||||
GemmConfig::NumWaveGroups,
|
||||
GemmConfig::Preshuffle>;
|
||||
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
GemmConfig::NumWaveGroups,
|
||||
false, /*FixedVectorSize_*/
|
||||
1, /*VectorSizeC_*/
|
||||
false, /*TiledMMAPermuteN_*/
|
||||
1, /*BlockedXDLN_PerWarp_*/
|
||||
GemmConfig::DoubleSmemBuffer>>;
|
||||
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
|
||||
const ck_tile::index_t tiles_m =
|
||||
ck_tile::integer_divide_ceil(args.M, TilePartitioner::MPerBlock);
|
||||
// Balance signal granularity (smaller chunks = finer control) vs overhead (more signals)
|
||||
const ck_tile::index_t tiles_per_chunk = 2;
|
||||
// Shift chunk assignments to test wraparound behavior
|
||||
const ck_tile::index_t tile_idx_pivot = tiles_per_chunk;
|
||||
// Account for pivot when allocating signal buffer
|
||||
const ck_tile::index_t num_chunks =
|
||||
ck_tile::integer_divide_ceil(tiles_m + tile_idx_pivot, tiles_per_chunk);
|
||||
|
||||
std::cout << "Async Input Scheduler Test:" << std::endl;
|
||||
std::cout << " M tiles: " << tiles_m << std::endl;
|
||||
std::cout << " Tiles per chunk: " << tiles_per_chunk << std::endl;
|
||||
std::cout << " Tile index pivot: " << tile_idx_pivot << std::endl;
|
||||
std::cout << " Number of signal chunks: " << num_chunks << std::endl;
|
||||
|
||||
// Signals must start as zero so kernel blocks until producer sets them
|
||||
ck_tile::DeviceMem signal_buf(num_chunks * sizeof(uint32_t));
|
||||
signal_buf.SetZero();
|
||||
uint32_t* d_chunk_signals = static_cast<uint32_t*>(signal_buf.GetDeviceBuffer());
|
||||
|
||||
// Setup async input scheduler
|
||||
ck_tile::PersistentAsyncInputScheduler async_scheduler;
|
||||
async_scheduler.tiles_per_chunk_m = tiles_per_chunk;
|
||||
async_scheduler.chunk_signals = d_chunk_signals;
|
||||
async_scheduler.tile_idx_pivot_m = tile_idx_pivot;
|
||||
async_scheduler.num_chunks = num_chunks;
|
||||
|
||||
// Create modified host args with async scheduler
|
||||
ck_tile::UniversalGemmHostArgs<1, 1, 0> host_args({args.a_ptr},
|
||||
{args.b_ptr},
|
||||
{},
|
||||
args.e_ptr,
|
||||
args.k_batch,
|
||||
args.M,
|
||||
args.N,
|
||||
args.K,
|
||||
{args.stride_A},
|
||||
{args.stride_B},
|
||||
{},
|
||||
args.stride_E,
|
||||
async_scheduler);
|
||||
|
||||
auto kargs = Kernel::UniversalGemmKernel::MakeKernelArgs(host_args);
|
||||
|
||||
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
std::cout << " Grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< std::endl;
|
||||
std::cout << " Blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
|
||||
// Separate stream prevents deadlock: kernel and signal producer must run concurrently
|
||||
hipStream_t signal_stream;
|
||||
HIP_CHECK_ERROR(hipStreamCreateWithFlags(&signal_stream, hipStreamNonBlocking));
|
||||
|
||||
const auto start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
// Simulate incremental input arrival by delaying signal activation
|
||||
const int sleep_us = 100;
|
||||
for(ck_tile::index_t i = 0; i < num_chunks; ++i)
|
||||
{
|
||||
std::this_thread::sleep_for(std::chrono::microseconds(sleep_us));
|
||||
const uint32_t signal_val = 1;
|
||||
HIP_CHECK_ERROR(hipMemcpyAsync(d_chunk_signals + i,
|
||||
&signal_val,
|
||||
sizeof(uint32_t),
|
||||
hipMemcpyHostToDevice,
|
||||
signal_stream));
|
||||
}
|
||||
HIP_CHECK_ERROR(hipStreamSynchronize(signal_stream));
|
||||
HIP_CHECK_ERROR(hipStreamDestroy(signal_stream));
|
||||
|
||||
// Wait for kernel completion
|
||||
HIP_CHECK_ERROR(hipDeviceSynchronize());
|
||||
|
||||
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(
|
||||
std::chrono::high_resolution_clock::now() - start);
|
||||
|
||||
std::cout << " Total time: " << duration.count() << " us" << std::endl;
|
||||
std::cout << " Sleep time: " << (num_chunks * sleep_us) << " us" << std::endl;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -59,7 +59,8 @@ float grouped_gemm_abquant(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
false, // PreshuffleQuant
|
||||
false, // APreshuffleQuant
|
||||
false, // BPreshuffleQuant
|
||||
GemmConfig::PreshuffleB,
|
||||
ALayout,
|
||||
BLayout,
|
||||
@@ -202,7 +203,8 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
false, // PreshuffleQuant
|
||||
false, // APreshuffleQuant
|
||||
false, // BPreshuffleQuant
|
||||
GemmConfig::PreshuffleB,
|
||||
ALayout,
|
||||
BLayout,
|
||||
|
||||
@@ -44,7 +44,8 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
false, // PreshuffleQuant
|
||||
false, // APreshuffleQuant
|
||||
false, // BPreshuffleQuant
|
||||
GemmConfig::PreshuffleB,
|
||||
ALayout,
|
||||
BLayout,
|
||||
@@ -210,7 +211,8 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
false, // PreshuffleQuant
|
||||
false, // APreshuffleQuant
|
||||
false, // BPreshuffleQuant
|
||||
GemmConfig::PreshuffleB,
|
||||
ALayout,
|
||||
BLayout,
|
||||
|
||||
@@ -257,6 +257,24 @@ struct ConvTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t>
|
||||
template <ck_tile::GemmPipeline PipelineId>
|
||||
struct PipelineTypeTraits;
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::BASIC_V1>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAGmemBGmemCRegV1<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::BASIC_V2>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV2<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAGmemBGmemCRegV2<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::MEMORY>
|
||||
{
|
||||
|
||||
@@ -6,6 +6,7 @@ if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -Wno-global-constructors) # use global constructors to add kernel instances
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
@@ -20,9 +21,18 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
gemm_bquant_quantgrouped_bf16mxfp4.cpp
|
||||
gemm_bquant_quantgrouped_bf8.cpp
|
||||
gemm_bquant_quantgrouped_fp8.cpp
|
||||
gemm_bquant_quantgrouped_preshuffleb.cpp
|
||||
gemm_bquant_quantgrouped_preshufflequant.cpp
|
||||
gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp
|
||||
gemm_bquant_quantgrouped_preshuffleb_bf8i4.cpp
|
||||
gemm_bquant_quantgrouped_preshuffleb_fp8i4.cpp
|
||||
gemm_bquant_quantgrouped_preshuffleb_bf8.cpp
|
||||
gemm_bquant_quantgrouped_preshuffleb_fp8.cpp
|
||||
gemm_bquant_quantgrouped_preshufflequant_bf8i4.cpp
|
||||
gemm_bquant_quantgrouped_preshufflequant_fp8i4.cpp
|
||||
gemm_bquant_quantgrouped_preshufflequant_bf8.cpp
|
||||
gemm_bquant_quantgrouped_preshufflequant_fp8.cpp
|
||||
gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4.cpp
|
||||
gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4.cpp
|
||||
gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8.cpp
|
||||
gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8.cpp
|
||||
gemm_quant_rowcol.cpp
|
||||
gemm_quant_tensor.cpp
|
||||
)
|
||||
|
||||
@@ -4,11 +4,16 @@
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigQuantPrefill<T>;
|
||||
using GemmConfig = GemmConfigABQuantPrefill<T>;
|
||||
|
||||
void abquant_quantgrouped_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
template <typename T>
|
||||
using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Prefill<T>;
|
||||
|
||||
// template <typename T>
|
||||
// using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Decode<T>;
|
||||
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"abquant",
|
||||
"non-preshuffleb",
|
||||
@@ -78,7 +83,7 @@ void abquant_quantgrouped_instance_factory(
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleB_BQuant_Prefill<ck_tile::fp8_t>,
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleB<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
@@ -93,7 +98,7 @@ void abquant_quantgrouped_instance_factory(
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleB_BQuant_Prefill<ck_tile::fp8_t>,
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleB<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
@@ -108,7 +113,7 @@ void abquant_quantgrouped_instance_factory(
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleB_BQuant_Prefill<ck_tile::bf8_t>,
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleB<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
@@ -123,10 +128,41 @@ void abquant_quantgrouped_instance_factory(
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleB_BQuant_Prefill<ck_tile::bf8_t>,
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleB<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
}
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"abquant",
|
||||
"non-preshuffleb",
|
||||
"preshufflequant",
|
||||
"1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleBQuantPrefill<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"abquant",
|
||||
"non-preshuffleb",
|
||||
"preshufflequant",
|
||||
"1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleBQuantPrefill<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
return 0;
|
||||
}();
|
||||
|
||||
@@ -4,15 +4,14 @@
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigQuantDecode<T>;
|
||||
using GemmConfig = GemmConfigQuantDecodeInterwave<T>;
|
||||
|
||||
// GemmConfigQuantPrefill is also supported for aquant grouped quantization
|
||||
// template <typename T>
|
||||
// using GemmConfig = GemmConfigQuantPrefill<T>;
|
||||
|
||||
void aquant_quantgrouped_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8", "aquant", "non-preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser&
|
||||
@@ -56,4 +55,5 @@ void aquant_quantgrouped_instance_factory(
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::AQuantGrouped>(arg_parser);
|
||||
};
|
||||
}
|
||||
return 0;
|
||||
}();
|
||||
|
||||
@@ -6,9 +6,8 @@
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleQuantDecode<T>;
|
||||
|
||||
void aquant_quantgrouped_preshufflequant_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8", "aquant", "preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser&
|
||||
@@ -52,4 +51,5 @@ void aquant_quantgrouped_preshufflequant_instance_factory(
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::AQuantGrouped>(arg_parser);
|
||||
};
|
||||
}
|
||||
return 0;
|
||||
}();
|
||||
|
||||
@@ -12,9 +12,8 @@ using GemmConfig = GemmConfigQuantPrefill<T>;
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_bf16fp4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf16_t,
|
||||
ck_tile::pk_fp4_raw_t,
|
||||
ck_tile::bf16_t,
|
||||
@@ -38,4 +37,5 @@ void bquant_quantgrouped_bf16fp4_instance_factory(
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
return 0;
|
||||
}();
|
||||
|
||||
@@ -12,9 +12,8 @@ using GemmConfig = GemmConfigQuantPrefill<T>;
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_bf8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
#ifndef CK_GFX950_SUPPORT
|
||||
@@ -49,4 +48,11 @@ void bquant_quantgrouped_bf8_instance_factory(
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
return 0;
|
||||
}();
|
||||
|
||||
@@ -12,9 +12,8 @@ using GemmConfig = GemmConfigQuantPrefill<T>;
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_bf8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
@@ -51,4 +50,11 @@ void bquant_quantgrouped_bf8i4_instance_factory(
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
return 0;
|
||||
}();
|
||||
|
||||
@@ -12,9 +12,8 @@ using GemmConfig = GemmConfigQuantPrefill<T>;
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_fp8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
#ifndef CK_GFX950_SUPPORT
|
||||
@@ -49,4 +48,11 @@ void bquant_quantgrouped_fp8_instance_factory(
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
return 0;
|
||||
}();
|
||||
|
||||
@@ -12,9 +12,8 @@ using GemmConfig = GemmConfigQuantPrefill<T>;
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_fp8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
@@ -51,4 +50,11 @@ void bquant_quantgrouped_fp8i4_instance_factory(
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
return 0;
|
||||
}();
|
||||
|
||||
@@ -1,222 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
#if CK_TILE_USE_WMMA
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill_Wmma<T>;
|
||||
#else
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill<T>;
|
||||
#endif
|
||||
|
||||
void bquant_quantgrouped_preshuffleb_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"bquant",
|
||||
"preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"bquant",
|
||||
"preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8",
|
||||
"bquant",
|
||||
"preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8",
|
||||
"bquant",
|
||||
"preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
#if CK_TILE_USE_WMMA
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill_Wmma<T>;
|
||||
#else
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill<T>;
|
||||
#endif
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \
|
||||
TypeConfig, \
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
return 0;
|
||||
}();
|
||||
@@ -0,0 +1,57 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
#if CK_TILE_USE_WMMA
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill_Wmma<T>;
|
||||
#else
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill<T>;
|
||||
#endif
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \
|
||||
TypeConfig, \
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
return 0;
|
||||
}();
|
||||
@@ -0,0 +1,53 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
#if CK_TILE_USE_WMMA
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill_Wmma<T>;
|
||||
#else
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill<T>;
|
||||
#endif
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \
|
||||
TypeConfig, \
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
return 0;
|
||||
}();
|
||||
@@ -0,0 +1,57 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
#if CK_TILE_USE_WMMA
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill_Wmma<T>;
|
||||
#else
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill<T>;
|
||||
#endif
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \
|
||||
TypeConfig, \
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
return 0;
|
||||
}();
|
||||
@@ -1,62 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
#if CK_TILE_USE_WMMA
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill_Wmma<T>;
|
||||
#else
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill<T>;
|
||||
#endif
|
||||
|
||||
void bquant_quantgrouped_preshuffleb_preshufflequant_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
#if CK_TILE_USE_WMMA
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill_Wmma<T>;
|
||||
#else
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill<T>;
|
||||
#endif
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \
|
||||
TypeConfig, \
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
return 0;
|
||||
}();
|
||||
@@ -0,0 +1,52 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
#if CK_TILE_USE_WMMA
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill_Wmma<T>;
|
||||
#else
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill<T>;
|
||||
#endif
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \
|
||||
TypeConfig, \
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
return 0;
|
||||
}();
|
||||
@@ -0,0 +1,50 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
#if CK_TILE_USE_WMMA
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill_Wmma<T>;
|
||||
#else
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill<T>;
|
||||
#endif
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \
|
||||
TypeConfig, \
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
return 0;
|
||||
}();
|
||||
@@ -0,0 +1,52 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
#if CK_TILE_USE_WMMA
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill_Wmma<T>;
|
||||
#else
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill<T>;
|
||||
#endif
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \
|
||||
TypeConfig, \
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
return 0;
|
||||
}();
|
||||
@@ -1,270 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleBQuantPrefill<T>;
|
||||
|
||||
void bquant_quantgrouped_preshufflequant_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"bquant",
|
||||
"non-preshuffleb",
|
||||
"preshufflequant",
|
||||
"1x16x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"bquant",
|
||||
"non-preshuffleb",
|
||||
"preshufflequant",
|
||||
"1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"bquant",
|
||||
"non-preshuffleb",
|
||||
"preshufflequant",
|
||||
"1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8",
|
||||
"bquant",
|
||||
"non-preshuffleb",
|
||||
"preshufflequant",
|
||||
"1x16x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8",
|
||||
"bquant",
|
||||
"non-preshuffleb",
|
||||
"preshufflequant",
|
||||
"1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8",
|
||||
"bquant",
|
||||
"non-preshuffleb",
|
||||
"preshufflequant",
|
||||
"1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleBQuantPrefill<T>;
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \
|
||||
TypeConfig, \
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
return 0;
|
||||
}();
|
||||
@@ -0,0 +1,59 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleBQuantPrefill<T>;
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \
|
||||
TypeConfig, \
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
return 0;
|
||||
}();
|
||||
@@ -0,0 +1,55 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleBQuantPrefill<T>;
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \
|
||||
TypeConfig, \
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
return 0;
|
||||
}();
|
||||
@@ -0,0 +1,59 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleBQuantPrefill<T>;
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \
|
||||
TypeConfig, \
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
return 0;
|
||||
}();
|
||||
@@ -95,33 +95,6 @@ auto gen_lut_key(const ck_tile::ArgParser& arg_parser)
|
||||
return hash_multiple_strings(params);
|
||||
}
|
||||
|
||||
void abquant_quantgrouped_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void aquant_quantgrouped_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void aquant_quantgrouped_preshufflequant_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_fp8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_bf8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_fp8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_bf8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_bf16fp4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshuffleb_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshufflequant_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshuffleb_preshufflequant_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void quant_rowcol_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void quant_tensor_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
@@ -135,20 +108,8 @@ int main(int argc, char* argv[])
|
||||
std::cout << "Device ID: " << device_id << std::endl;
|
||||
ck_tile::hip_check_error(hipSetDevice(device_id));
|
||||
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>> lut;
|
||||
abquant_quantgrouped_instance_factory(lut);
|
||||
aquant_quantgrouped_instance_factory(lut);
|
||||
aquant_quantgrouped_preshufflequant_instance_factory(lut);
|
||||
bquant_quantgrouped_fp8_instance_factory(lut);
|
||||
bquant_quantgrouped_bf8_instance_factory(lut);
|
||||
bquant_quantgrouped_fp8i4_instance_factory(lut);
|
||||
bquant_quantgrouped_bf8i4_instance_factory(lut);
|
||||
bquant_quantgrouped_bf16fp4_instance_factory(lut);
|
||||
bquant_quantgrouped_preshuffleb_instance_factory(lut);
|
||||
bquant_quantgrouped_preshufflequant_instance_factory(lut);
|
||||
bquant_quantgrouped_preshuffleb_preshufflequant_instance_factory(lut);
|
||||
quant_rowcol_instance_factory(lut);
|
||||
quant_tensor_instance_factory(lut);
|
||||
auto& lut = get_kernel_lut();
|
||||
std::cout << "Available kernels: " << lut.size() << std::endl;
|
||||
|
||||
auto key = gen_lut_key(arg_parser);
|
||||
|
||||
|
||||
@@ -6,9 +6,8 @@
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigQuantDecode<T>;
|
||||
|
||||
void quant_rowcol_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
// NOTE: QuantGroupSize is a place holder. rowcol pipeline does not use QuantGroupSize
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 1>>;
|
||||
lut[hash_multiple_strings({"fp8", "rowcol"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
@@ -27,4 +26,5 @@ void quant_rowcol_instance_factory(
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::RowColQuant>(arg_parser);
|
||||
};
|
||||
}
|
||||
return 0;
|
||||
}();
|
||||
|
||||
@@ -6,9 +6,8 @@
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigQuantDecode<T>;
|
||||
|
||||
void quant_tensor_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
static auto _ = []() {
|
||||
auto& lut = get_kernel_lut();
|
||||
// NOTE: QuantGroupSize is a place holder. tensor pipeline does not use QuantGroupSize
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 1>>;
|
||||
lut[hash_multiple_strings({"fp8", "tensor"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
@@ -27,4 +26,5 @@ void quant_tensor_instance_factory(
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::TensorQuant>(arg_parser);
|
||||
};
|
||||
}
|
||||
return 0;
|
||||
}();
|
||||
|
||||
@@ -11,6 +11,14 @@
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm_quant.hpp"
|
||||
|
||||
inline auto& get_kernel_lut()
|
||||
{
|
||||
// In an inline function, function-local static objects in all function definitions are shared
|
||||
// across all translation units.
|
||||
static std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>> lut;
|
||||
return lut;
|
||||
}
|
||||
|
||||
inline size_t hash_multiple_strings(const std::vector<std::string>& inputs)
|
||||
{
|
||||
std::hash<std::string> hasher;
|
||||
@@ -72,7 +80,8 @@ struct GemmConfigBase
|
||||
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
static constexpr bool PreshuffleQuant = false;
|
||||
static constexpr bool APreshuffleQuant = false;
|
||||
static constexpr bool BPreshuffleQuant = false;
|
||||
static constexpr bool PreshuffleB = false;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
@@ -93,6 +102,27 @@ struct GemmConfigQuantDecode : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
// static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigQuantDecodeInterwave : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -128,7 +158,8 @@ struct GemmConfigPreshuffleQuantDecode : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
|
||||
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
static constexpr bool APreshuffleQuant = true;
|
||||
static constexpr bool BPreshuffleQuant = true;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -158,7 +189,7 @@ template <typename PrecType>
|
||||
struct GemmConfigPreshuffleB_PreshuffleBQuant_Decode
|
||||
: public GemmConfigPreshuffleB_BQuant_Decode<PrecType>
|
||||
{
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
static constexpr bool BPreshuffleQuant = true;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -189,7 +220,29 @@ template <typename PrecType>
|
||||
struct GemmConfigPreshuffleB_PreshuffleBQuant_Prefill
|
||||
: public GemmConfigPreshuffleB_BQuant_Prefill<PrecType>
|
||||
{
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
static constexpr bool BPreshuffleQuant = true;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshuffleB_ABQuant_Prefill : public GemmConfigPreshuffleB_BQuant_Prefill<PrecType>
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr bool kPadK = false;
|
||||
static constexpr bool TransposeC = true;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshuffleB_ABQuant_Decode : public GemmConfigPreshuffleB_BQuant_Prefill<PrecType>
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
|
||||
|
||||
static constexpr bool kPadK = false;
|
||||
static constexpr bool TransposeC = true;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -207,12 +260,21 @@ struct GemmConfigQuantPrefill : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
// static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigABQuantPrefill : public GemmConfigQuantPrefill<PrecType>
|
||||
{
|
||||
static constexpr bool kPadK = false;
|
||||
static constexpr bool TransposeC = true;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshuffleBQuantPrefill : public GemmConfigQuantPrefill<PrecType>
|
||||
{
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
static constexpr bool BPreshuffleQuant = true;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
|
||||
@@ -33,6 +33,8 @@ template <typename GemmConfig,
|
||||
float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
constexpr bool transpose_c =
|
||||
GemmConfig::TransposeC; // QuantMode == ck_tile::QuantType::ABQuantGrouped;
|
||||
using ComputeDataType = std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::RowColQuant,
|
||||
typename TypeConfig::BDataType,
|
||||
@@ -49,15 +51,16 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
using GemmTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
GemmConfig::PreshuffleQuant,
|
||||
GemmConfig::APreshuffleQuant,
|
||||
GemmConfig::BPreshuffleQuant,
|
||||
GemmConfig::PreshuffleB,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
QuantMode,
|
||||
AQLayout, // for AQLayout
|
||||
BQLayout, // for BQLayout
|
||||
false,
|
||||
AQLayout,
|
||||
BQLayout,
|
||||
transpose_c,
|
||||
GemmConfig::DoubleSmemBuffer>;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<typename TypeConfig::ADataType,
|
||||
@@ -72,23 +75,21 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
GemmConfig::PreshuffleB == true,
|
||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>,
|
||||
std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::PreshuffleQuant == true,
|
||||
QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::APreshuffleQuant == true,
|
||||
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>,
|
||||
std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>,
|
||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>>>>;
|
||||
|
||||
const ck_tile::index_t K_split =
|
||||
(args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::index_t K_split = ck_tile::integer_least_multiple(args.K, GemmConfig::K_Tile);
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr bool transpose_c = false;
|
||||
|
||||
// row-col and tensor quants use the regular pipeline, A/B/AB quants use their own
|
||||
using PipelineProblem = std::conditional_t<
|
||||
@@ -147,7 +148,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
has_hot_loop_v,
|
||||
tail_number_v>>>>;
|
||||
using AQuantPipeline =
|
||||
std::conditional_t<GemmConfig::PreshuffleQuant,
|
||||
std::conditional_t<GemmConfig::APreshuffleQuant,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrMem<PipelineProblem>>;
|
||||
|
||||
@@ -391,8 +392,10 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
std::cout << " Acc_Type = " << ck_tile::DataTypeTraits<typename TypeConfig::AccDataType>::name
|
||||
<< " C_Type = " << ck_tile::DataTypeTraits<typename TypeConfig::CDataType>::name
|
||||
<< " QuantMode = " << quant_type_to_string(QuantMode)
|
||||
<< " PreshuffleQuant = " << (GemmConfig::PreshuffleQuant ? "true" : "false") << " : "
|
||||
<< " PreshuffleB = " << (GemmConfig::PreshuffleB ? "true" : "false") << " : "
|
||||
<< " APreshuffleQuant = " << (GemmConfig::APreshuffleQuant ? "true" : "false")
|
||||
<< " : "
|
||||
<< " BPreshuffleQuant = " << (GemmConfig::BPreshuffleQuant ? "true" : "false")
|
||||
<< " : " << " PreshuffleB = " << (GemmConfig::PreshuffleB ? "true" : "false") << " : "
|
||||
<< ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< std::endl;
|
||||
|
||||
@@ -537,24 +540,15 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
// Create BQ tensor with appropriate shape
|
||||
std::unique_ptr<ck_tile::HostTensor<BQDataType>> bq_tensor_ptr = nullptr;
|
||||
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
QuantMode == ck_tile::QuantType::ABQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant)
|
||||
{
|
||||
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
|
||||
ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, is_row_major(bq_layout)));
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
|
||||
{
|
||||
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
|
||||
ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, is_row_major(bq_layout)));
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant)
|
||||
{
|
||||
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
|
||||
ck_tile::host_tensor_descriptor(1, 1, stride_BQ, is_row_major(bq_layout)));
|
||||
}
|
||||
|
||||
std::random_device rd;
|
||||
std::mt19937 gen(rd());
|
||||
std::mt19937 gen(42);
|
||||
std::uniform_int_distribution<std::uint32_t> fill_seed(0, 500);
|
||||
|
||||
if(init_method == 0)
|
||||
@@ -630,7 +624,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
else if(init_method == 1)
|
||||
{
|
||||
std::cout << "Monotonic initialization is not supported." << std::endl;
|
||||
return 0;
|
||||
return -1;
|
||||
}
|
||||
else if(init_method == 2)
|
||||
{
|
||||
@@ -650,7 +644,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
else
|
||||
{
|
||||
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x22)}(a_m_k);
|
||||
ck_tile::FillConstant<AQDataType>{static_cast<AQDataType>(0.5f)}(*aq_tensor_ptr);
|
||||
ck_tile::FillConstant<AQDataType>{static_cast<AQDataType>(1.0f)}(*aq_tensor_ptr);
|
||||
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x38)}(b_k_n);
|
||||
|
||||
if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
@@ -659,6 +653,184 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
}
|
||||
}
|
||||
}
|
||||
else if(init_method == 3)
|
||||
{
|
||||
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x38)}(a_m_k);
|
||||
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x22)}(b_k_n);
|
||||
ck_tile::FillConstant<BQDataType>{static_cast<BQDataType>(0.5f)}(*bq_tensor_ptr);
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
|
||||
{
|
||||
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x38)}(a_m_k);
|
||||
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x22)}(b_k_n);
|
||||
ck_tile::FillConstant<AQDataType>{static_cast<AQDataType>(0.5f)}(*aq_tensor_ptr);
|
||||
ck_tile::FillConstant<BQDataType>{static_cast<BQDataType>(0.5f)}(*bq_tensor_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x22)}(a_m_k);
|
||||
ck_tile::FillConstant<AQDataType>{static_cast<AQDataType>(2.0f)}(*aq_tensor_ptr);
|
||||
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x38)}(b_k_n);
|
||||
|
||||
if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
{
|
||||
ck_tile::FillConstant<BQDataType>{static_cast<BQDataType>(0.5f)}(*bq_tensor_ptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
else if(init_method == 4)
|
||||
{
|
||||
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
|
||||
b_k_n);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
else if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{125.f, 130.f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f, fill_seed(gen)}(a_m_k);
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
|
||||
a_m_k);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f, fill_seed(gen)}(a_m_k);
|
||||
}
|
||||
ck_tile::FillUniformDistribution<AQDataType>{2.0f, 2.0f, fill_seed(gen)}(
|
||||
*aq_tensor_ptr);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
|
||||
a_m_k);
|
||||
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
|
||||
b_k_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f, fill_seed(gen)}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
|
||||
}
|
||||
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*aq_tensor_ptr);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 2.0f, fill_seed(gen)}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 2.0f, fill_seed(gen)}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*aq_tensor_ptr);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
}
|
||||
else if(init_method == 5)
|
||||
{
|
||||
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
|
||||
b_k_n);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
else if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{125.f, 130.f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f, fill_seed(gen)}(a_m_k);
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
|
||||
a_m_k);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{1.0f, 1.0f, fill_seed(gen)}(a_m_k);
|
||||
}
|
||||
// Fill aquant such that column j has value 2^j (1, 2, 4, 8, ...)
|
||||
for(ck_tile::index_t row = 0;
|
||||
row < static_cast<ck_tile::index_t>(aq_tensor_ptr->get_length(0));
|
||||
++row)
|
||||
{
|
||||
for(ck_tile::index_t col = 0;
|
||||
col < static_cast<ck_tile::index_t>(aq_tensor_ptr->get_length(1));
|
||||
++col)
|
||||
{
|
||||
(*aq_tensor_ptr)(row, col) = static_cast<AQDataType>(col + 1);
|
||||
}
|
||||
}
|
||||
// std::cout << "aq_tensor_ptr: " << *aq_tensor_ptr << std::endl;
|
||||
ck_tile::FillUniformDistribution<BDataType>{1.0f, 1.0f, fill_seed(gen)}(b_k_n);
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
|
||||
a_m_k);
|
||||
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
|
||||
b_k_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f, fill_seed(gen)}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
|
||||
}
|
||||
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*aq_tensor_ptr);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 2.0f, fill_seed(gen)}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 2.0f, fill_seed(gen)}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*aq_tensor_ptr);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
a_m_k.SetZero();
|
||||
@@ -694,7 +866,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant)
|
||||
{
|
||||
if constexpr(GemmConfig::PreshuffleQuant)
|
||||
if constexpr(GemmConfig::APreshuffleQuant)
|
||||
{
|
||||
ck_tile::HostTensor<AQDataType> aq_shuffle_host =
|
||||
ck_tile::shuffle_aq(aq_tensor_ptr.get(), GemmConfig::K_Tile / AQuantGroupSize::kK);
|
||||
@@ -753,7 +925,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
ck_tile::HostTensor<BQDataType> bq_permuted_host =
|
||||
ck_tile::bq_permuteN<GemmConfig>(*bq_tensor_ptr, BQuantGroupSize::kN);
|
||||
|
||||
if constexpr(GemmConfig::PreshuffleQuant)
|
||||
if constexpr(GemmConfig::BPreshuffleQuant)
|
||||
{
|
||||
ck_tile::HostTensor<BQDataType> bq_shuffle_host = ck_tile::shuffle_bq(
|
||||
&bq_permuted_host, GemmConfig::K_Tile / BQuantGroupSize::kK);
|
||||
@@ -764,7 +936,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
bq_dev_buf_ptr->ToDevice(bq_permuted_host.data());
|
||||
}
|
||||
}
|
||||
else if constexpr(GemmConfig::PreshuffleQuant)
|
||||
else if constexpr(GemmConfig::BPreshuffleQuant)
|
||||
{
|
||||
ck_tile::HostTensor<BQDataType> bq_shuffle_host =
|
||||
ck_tile::shuffle_bq(bq_tensor_ptr.get(), GemmConfig::K_Tile / BQuantGroupSize::kK);
|
||||
@@ -900,10 +1072,10 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
else if(arg_parser.get_int("v") == 2)
|
||||
{
|
||||
std::cout << "GPU verification is not implemented yet. Re-run with -v=1" << std::endl;
|
||||
return false;
|
||||
return -1;
|
||||
}
|
||||
|
||||
return pass;
|
||||
return pass ? 0 : -1;
|
||||
}
|
||||
// Usage of Two-Matrix Quantization (AB-Quant)
|
||||
template <typename GemmConfig,
|
||||
@@ -945,7 +1117,7 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
if constexpr((QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::ABQuantGrouped) &&
|
||||
!GemmConfig::PreshuffleQuant && !GemmConfig::PreshuffleB)
|
||||
!GemmConfig::APreshuffleQuant && !GemmConfig::PreshuffleB)
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "R")
|
||||
{
|
||||
@@ -966,7 +1138,8 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
|
||||
arg_parser, Col{}, Row{}, Row{}, Col{}, Row{});
|
||||
}
|
||||
}
|
||||
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped && !GemmConfig::PreshuffleQuant)
|
||||
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped &&
|
||||
!GemmConfig::APreshuffleQuant)
|
||||
{
|
||||
if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user