merge upstream

This commit is contained in:
Yanxing-Shi
2025-05-27 09:31:20 +00:00
19 changed files with 350 additions and 237 deletions

View File

@@ -18,7 +18,10 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
* Added support for Split K for grouped convolution backward data.
* Added logit soft-capping support for fMHA forward kernels.
* Added benchmarking support for tile engine GEMM.
<<<<<<< HEAD
* Added profiling cache support for tile engine GEMM.
=======
>>>>>>> upstream/develop
### Optimized

View File

@@ -58,7 +58,8 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
{F_lse},
{F_dropout},
{F_squant},
{F_occupancy}>;
{F_occupancy},
{F_skip}>;
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
@@ -94,7 +95,7 @@ using fmha_kernel_{F_idx} =
ck_tile::FmhaFwdKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
#include <iostream>
@@ -129,9 +130,9 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
}}
"""
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
return fmha_fwd_<trait_>(s, a);
}}
"""
@@ -160,11 +161,12 @@ class FmhaFwdApiTrait:
skpad : str
dpad : str
dvpad : str
skip : str
@property
def name(self) -> str:
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\
f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'
f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}'
@property
def scheck(self) -> str:
@@ -227,6 +229,7 @@ class FmhaFwdPipeline:
F_dropout : str #
F_squant : str #
F_mask : str # value from MASK_MAP
F_skip : str # true/false
@property
def name(self) -> str:
@@ -262,8 +265,12 @@ class FmhaFwdPipeline:
if self.F_dropout == 't' : n += '_dropout'
else: n += '_ndropout'
if self.F_skip == 't' : n += '_skip'
else: n += '_nskip'
if self.F_squant == 't' : n += '_squant'
else: n += '_nsquant'
return n
class FmhaFwdApiPool:
@@ -293,7 +300,7 @@ class FmhaFwdApiPool:
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] ,
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip],
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
@@ -381,6 +388,7 @@ class FmhaFwdKernel:
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
F_dropout = BOOL_MAP[self.F_pipeline.F_dropout],
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
F_skip = BOOL_MAP[self.F_pipeline.F_skip],
F_occupancy = self.F_tile.F_occupancy,
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
@@ -419,7 +427,8 @@ class FmhaFwdKernel:
spad=self.F_pipeline.F_spad,
skpad=self.F_pipeline.F_skpad,
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad)
dvpad=self.F_pipeline.F_dvpad,
skip=self.F_pipeline.F_skip)
# TODO: design a more practical way to do it
# this is current supported tile size per hdim
@@ -453,36 +462,36 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
squant = 't' if dtype == 'fp8' else 'f'
pipelines = []
if dtype in ['fp16', 'bf16']:
for logits, mask, bias, lse, dropout in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]):
if hdim == 256:
# if True:
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
# the below two is used for hdim vectorize load
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
else:
if bias == "bias":
# TODO: rocm 6.2 compiler problem if using qr_async for bias case
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
else:
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
if receipt == 1 and bias != "bias":
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) # TODO: cover arbitraty hdim
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) # TODO: cover arbitraty hdim
elif dtype in ['fp8', 'bf8']:
# no need lse/dropout kernels
for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f'))
elif dtype in ['fp8fp16', 'fp8bf16']:
# TODO
None
@@ -532,6 +541,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_bias in ['no', 'alibi']
cond &= pipeline.F_squant == 'f'
cond &= pipeline.F_skip == 'f'
if not cond:
continue
# PyTorch integration
@@ -540,6 +550,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_bias in ['no', 'bias']
cond &= pipeline.F_squant == 'f'
cond &= pipeline.F_skip == 'f'
if not cond:
continue
# Aiter(mha_fwd) integration
@@ -565,6 +576,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
cond &= pipeline.F_squant == 'f'
if not cond:
continue
api_pool.register_traits(k.api_trait())
gen.append(k)

View File

@@ -169,6 +169,7 @@ struct fmha_fwd_args
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t mask_type;
ck_tile::index_t min_seqlen_q;
float p_drop;
bool s_randval;
@@ -433,6 +434,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.window_size_left,
args.window_size_right,
args.mask_type,
args.min_seqlen_q,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
@@ -837,7 +839,8 @@ template <ck_tile::index_t HDim_,
bool kPadS_,
bool kPadSK_,
bool kPadD_,
bool kPadDv_>
bool kPadDv_,
bool kSkipMinSeqlenQ_ = false>
struct fmha_fwd_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
@@ -861,6 +864,7 @@ struct fmha_fwd_traits_
static constexpr bool kPadSK = kPadSK_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
};
template <typename Traits_>
@@ -995,6 +999,7 @@ struct fmha_fwd_traits
bool has_lse;
bool has_dropout;
bool do_fp8_static_quant;
bool skip_min_seqlen_q = false;
// TODO: padding check is inside this api
};
float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -393,8 +393,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
{
const index_t GemmM = K;
const index_t GemmN = C * X;
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
const auto PadGemmM =
GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
const auto PadGemmN =
GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
return transform_tensor_descriptor(
wei_grid_desc,
@@ -432,8 +434,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
{
const index_t GemmM = K;
const index_t GemmN = C * X * Y;
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
const auto PadGemmM =
GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
const auto PadGemmN =
GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
return transform_tensor_descriptor(
wei_grid_desc,
@@ -472,8 +476,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
{
const index_t GemmM = K;
const index_t GemmN = C * X * Y * Z;
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
const auto PadGemmM =
GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
const auto PadGemmN =
GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
return transform_tensor_descriptor(
wei_grid_desc,

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -208,8 +208,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
const index_t GemmM = K;
const index_t GemmN = C * Z * X * Y;
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock) * K0PerBlock;

View File

@@ -1,6 +1,6 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -166,8 +166,8 @@ struct TransformConvBwdWeightToGemm
const index_t GemmM = K;
const index_t GemmN = C * X;
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
const index_t GemmKBatch = batch_k;
const index_t GemmK0 =
@@ -365,8 +365,8 @@ struct TransformConvBwdWeightToGemm
const index_t GemmM = K;
const index_t GemmN = C * X * Y;
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
const index_t GemmKBatch = batch_k;
const index_t GemmK0 =
@@ -558,8 +558,8 @@ struct TransformConvBwdWeightToGemm
const index_t GemmM = K;
const index_t GemmN = C * Z * X * Y;
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
const index_t GemmKBatch = batch_k;
const index_t GemmK0 =

View File

@@ -346,8 +346,8 @@ struct TransformConvBwdWeightToGemmV2
const index_t GemmM = K * NumGroupsToMerge;
const index_t GemmN = C * X * NumGroupsToMerge;
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
const index_t GemmKBatch = batch_k;
const index_t GemmK0 =
@@ -534,8 +534,8 @@ struct TransformConvBwdWeightToGemmV2
const index_t GemmM = K * NumGroupsToMerge;
const index_t GemmN = C * X * Y * NumGroupsToMerge;
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
const index_t GemmKBatch = batch_k;
const index_t GemmK0 =
@@ -737,8 +737,8 @@ struct TransformConvBwdWeightToGemmV2
const index_t GemmM = K * NumGroupsToMerge;
const index_t GemmN = C * Z * X * Y * NumGroupsToMerge;
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
const index_t GemmKBatch = batch_k;
const index_t GemmK0 =

View File

@@ -55,8 +55,8 @@
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
#include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/core/tensor/tile_scatter_gather.hpp"
#include "ck_tile/core/tensor/tile_window_base.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/tile_window_base.hpp"
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/tile_window_utils.hpp"
#include "ck_tile/core/tensor/transpose_tile.hpp"

View File

@@ -35,4 +35,5 @@
#include "ck_tile/host/reference/reference_softmax.hpp"
#include "ck_tile/host/reference/reference_topk.hpp"
#include "ck_tile/host/stream_config.hpp"
#include "ck_tile/host/stream_utils.hpp"
#include "ck_tile/host/timer.hpp"

View File

@@ -53,6 +53,8 @@ struct FmhaFwdKernel
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ;
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
static constexpr bool kHasMask = FmhaMask::IsMasking;
@@ -257,6 +259,11 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_randval = 0;
};
struct FmhaFwdSkipMinSeqlenQKargs
{
ck_tile::index_t min_seqlen_q = 0;
};
struct FmhaFwdBatchModeKargs
: FmhaFwdCommonKargs,
std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
@@ -287,7 +294,8 @@ struct FmhaFwdKernel
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>,
std::conditional_t<kSkipMinSeqlenQ, FmhaFwdSkipMinSeqlenQKargs, FmhaFwdEmptyKargs<6>>
{
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
@@ -664,6 +672,7 @@ struct FmhaFwdKernel
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
ck_tile::index_t min_seqlen_q,
float p_drop,
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
@@ -698,6 +707,7 @@ struct FmhaFwdKernel
{}, // placeholder for fp8_static_quant args
{}, // placeholder for dropout
{}, // placeholder for logits_soft_cap
{}, // placeholder for min_seqlen_q
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
@@ -753,6 +763,10 @@ struct FmhaFwdKernel
{
kargs.init_logits_soft_cap(logits_soft_cap);
}
if constexpr(kSkipMinSeqlenQ)
{
kargs.min_seqlen_q = min_seqlen_q;
}
return kargs;
}
@@ -969,7 +983,15 @@ struct FmhaFwdKernel
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
if constexpr(kHasMask)
{
// assume that num_tile_n1 is always 1
return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
}
else
{
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
}
}
else
{
@@ -989,7 +1011,15 @@ struct FmhaFwdKernel
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
if constexpr(kHasMask)
{
// assume that num_tile_n1 is always 1
return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
}
else
{
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
}
}
}
@@ -1053,6 +1083,14 @@ struct FmhaFwdKernel
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
if constexpr(kSkipMinSeqlenQ)
{
if(kargs.seqlen_q <= kargs.min_seqlen_q)
{
return;
}
}
// # of required blocks is different in each groups, terminate unnecessary blocks
// earlier
if(kargs.seqlen_q <= i_m0)

View File

@@ -561,7 +561,16 @@ struct FmhaFwdSplitKVKernel
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
if constexpr(kHasMask)
{
// assume that num_tile_n1 is always 1
return ck_tile::make_tuple(
(gridDim.x / kargs.num_splits) - 1 - i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
}
else
{
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
}
}
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }

View File

@@ -53,6 +53,7 @@ struct BlockFmhaPipelineProblem
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap;
static constexpr bool kSkipMinSeqlenQ = Traits::kSkipMinSeqlenQ;
static constexpr auto BiasEnum = Traits::BiasEnum;
static constexpr bool kStoreLSE = Traits::kStoreLSE;
static constexpr bool kHasDropout = Traits::kHasDropout;

View File

@@ -19,7 +19,8 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kStoreLSE_,
bool kHasDropout_,
bool kDoFp8StaticQuant_,
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
bool kSkipMinSeqlenQ_ = false /* skip min seqlen q while chunked prefill */>
struct TileFmhaTraits
{
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
@@ -33,6 +34,7 @@ struct TileFmhaTraits
static constexpr bool kHasDropout = kHasDropout_;
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
};
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,

View File

@@ -1,6 +1,16 @@
# ckProfiler
set(PROFILER_SOURCES
profiler.cpp
set(CK_PROFILER_OP_FILTER "" CACHE STRING "Filter for the operators to be profiled. Default is to include all")
set(CK_PROFILER_INSTANCE_FILTER "" CACHE STRING "Filter for the kernels instances to be profiled. Default is to be the same as the operator filter")
if (CK_PROFILER_OP_FILTER STREQUAL "")
set(CK_PROFILER_OP_FILTER ".+")
endif()
if (CK_PROFILER_INSTANCE_FILTER STREQUAL "")
set(CK_PROFILER_INSTANCE_FILTER ${CK_PROFILER_OP_FILTER})
endif()
message(STATUS "CK_PROFILER_OP_FILTER: ${CK_PROFILER_OP_FILTER}")
message(STATUS "CK_PROFILER_INSTANCE_FILTER: ${CK_PROFILER_INSTANCE_FILTER}")
set(PROFILER_OPS
profile_gemm.cpp
profile_reduce.cpp
profile_groupnorm_bwd_data.cpp
@@ -26,161 +36,188 @@ set(PROFILER_SOURCES
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp)
list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp)
list(APPEND PROFILER_OPS profile_contraction_bilinear.cpp)
list(APPEND PROFILER_OPS profile_contraction_scale.cpp)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND PROFILER_SOURCES profile_gemm_reduce.cpp)
list(APPEND PROFILER_SOURCES profile_batched_gemm_gemm.cpp)
list(APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_add.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_add_add_fastgelu.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_add_fastgelu.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_gemm.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_streamk.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_fastgelu.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_add_relu.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_add_silu.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_gemm_fixed_nk.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_gemm_tile_loop.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_gemm_multiply_tile_loop.cpp)
list(APPEND PROFILER_OPS profile_gemm_reduce.cpp)
list(APPEND PROFILER_OPS profile_batched_gemm_gemm.cpp)
list(APPEND PROFILER_OPS profile_batched_gemm_add_relu_gemm_add.cpp)
list(APPEND PROFILER_OPS profile_gemm_add.cpp)
list(APPEND PROFILER_OPS profile_gemm_add_add_fastgelu.cpp)
list(APPEND PROFILER_OPS profile_gemm_add_fastgelu.cpp)
list(APPEND PROFILER_OPS profile_grouped_gemm.cpp)
list(APPEND PROFILER_OPS profile_gemm_streamk.cpp)
list(APPEND PROFILER_OPS profile_gemm_fastgelu.cpp)
list(APPEND PROFILER_OPS profile_gemm_add_relu.cpp)
list(APPEND PROFILER_OPS profile_gemm_add_silu.cpp)
list(APPEND PROFILER_OPS profile_gemm_add_relu_add_layernorm.cpp)
list(APPEND PROFILER_OPS profile_grouped_gemm_fixed_nk.cpp)
list(APPEND PROFILER_OPS profile_grouped_gemm_fastgelu.cpp)
list(APPEND PROFILER_OPS profile_grouped_gemm_tile_loop.cpp)
list(APPEND PROFILER_OPS profile_grouped_gemm_multiply_tile_loop.cpp)
endif()
list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp)
if(SUPPORTED_GPU_TARGETS MATCHES "gfx94")
list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply_wp.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_ab_scale.cpp)
list(APPEND PROFILER_OPS profile_gemm_multiply_add.cpp)
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]")
list(APPEND PROFILER_OPS profile_gemm_multiply_multiply.cpp)
list(APPEND PROFILER_OPS profile_gemm_multiply_multiply_wp.cpp)
list(APPEND PROFILER_OPS profile_gemm_ab_scale.cpp)
endif()
list(APPEND PROFILER_SOURCES profile_batched_gemm.cpp)
list(APPEND PROFILER_SOURCES profile_batched_gemm_reduce.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_add_multiply.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_bias_add_reduce.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_splitk.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_b_scale.cpp)
list(APPEND PROFILER_SOURCES profile_batched_gemm_b_scale.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_universal_batched.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_universal_reduce.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_universal_streamk.cpp)
list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu.cpp)
list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu_add.cpp)
list(APPEND PROFILER_SOURCES profile_conv_bwd_data.cpp)
list(APPEND PROFILER_SOURCES profile_conv_fwd.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_conv_fwd_outelementop.cpp)
list(APPEND PROFILER_OPS profile_batched_gemm.cpp)
list(APPEND PROFILER_OPS profile_batched_gemm_reduce.cpp)
list(APPEND PROFILER_OPS profile_gemm_add_multiply.cpp)
list(APPEND PROFILER_OPS profile_gemm_bias_add_reduce.cpp)
list(APPEND PROFILER_OPS profile_gemm_splitk.cpp)
list(APPEND PROFILER_OPS profile_gemm_b_scale.cpp)
list(APPEND PROFILER_OPS profile_batched_gemm_b_scale.cpp)
list(APPEND PROFILER_OPS profile_gemm_universal_batched.cpp)
list(APPEND PROFILER_OPS profile_gemm_universal_reduce.cpp)
list(APPEND PROFILER_OPS profile_gemm_universal_streamk.cpp)
list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu.cpp)
list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu_add.cpp)
list(APPEND PROFILER_OPS profile_conv_bwd_data.cpp)
list(APPEND PROFILER_OPS profile_conv_fwd.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_outelementop.cpp)
endif()
if(SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12" OR SUPPORTED_GPU_TARGETS MATCHES "gfx9")
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp)
list(APPEND PROFILER_OPS profile_gemm_bilinear.cpp)
endif()
list(APPEND PROFILER_SOURCES profile_gemm_universal.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_conv_fwd.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_data.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_weight.cpp)
list(APPEND PROFILER_OPS profile_gemm_universal.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_fwd.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_bwd_data.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp)
endif()
if(DL_KERNELS)
list(APPEND PROFILER_SOURCES profile_batched_gemm_multi_d.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_weight.cpp)
list(APPEND PROFILER_OPS profile_batched_gemm_multi_d.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp)
endif()
set(PROFILER_SOURCES profiler.cpp)
foreach(SOURCE ${PROFILER_OPS})
string(REGEX REPLACE "profile_(.+)\.cpp" "\\1" OP_NAME ${SOURCE})
if (OP_NAME STREQUAL "")
message(FATAL_ERROR "Unexpected source file name: ${SOURCE}")
endif()
if("${OP_NAME}" MATCHES "${CK_PROFILER_OP_FILTER}")
list(APPEND PROFILER_SOURCES ${SOURCE})
endif()
endforeach()
message(STATUS "ckProfiler sources: ${PROFILER_SOURCES}")
set(PROFILER_EXECUTABLE ckProfiler)
add_executable(${PROFILER_EXECUTABLE} ${PROFILER_SOURCES})
target_compile_options(${PROFILER_EXECUTABLE} PRIVATE -Wno-global-constructors)
# flags to compress the library
if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 600241132)
message("Adding --offload-compress flag for ${PROFILER_EXECUTABLE}")
message(STATUS "Adding --offload-compress flag for ${PROFILER_EXECUTABLE}")
target_compile_options(${PROFILER_EXECUTABLE} PRIVATE --offload-compress)
endif()
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE utility getopt::getopt)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_fwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_data_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_gamma_beta_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool2d_fwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool3d_fwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool2d_bwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool3d_bwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_max_pool_bwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_image_to_column_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_transpose_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_permute_scale_instance)
set(DEVICE_INSTANCES "")
list(APPEND DEVICE_INSTANCES device_gemm_instance)
list(APPEND DEVICE_INSTANCES device_normalization_fwd_instance)
list(APPEND DEVICE_INSTANCES device_normalization_bwd_data_instance)
list(APPEND DEVICE_INSTANCES device_normalization_bwd_gamma_beta_instance)
list(APPEND DEVICE_INSTANCES device_softmax_instance)
list(APPEND DEVICE_INSTANCES device_reduce_instance)
list(APPEND DEVICE_INSTANCES device_batchnorm_instance)
list(APPEND DEVICE_INSTANCES device_pool2d_fwd_instance)
list(APPEND DEVICE_INSTANCES device_pool3d_fwd_instance)
list(APPEND DEVICE_INSTANCES device_avg_pool2d_bwd_instance)
list(APPEND DEVICE_INSTANCES device_avg_pool3d_bwd_instance)
list(APPEND DEVICE_INSTANCES device_max_pool_bwd_instance)
list(APPEND DEVICE_INSTANCES device_image_to_column_instance)
list(APPEND DEVICE_INSTANCES device_column_to_image_instance)
list(APPEND DEVICE_INSTANCES device_transpose_instance)
list(APPEND DEVICE_INSTANCES device_permute_scale_instance)
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance)
list(APPEND DEVICE_INSTANCES device_contraction_bilinear_instance)
list(APPEND DEVICE_INSTANCES device_contraction_scale_instance)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgelu_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_fastgelu_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_streamk_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_silu_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fixed_nk_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_tile_loop_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_add_fastgelu_instance)
list(APPEND DEVICE_INSTANCES device_gemm_fastgelu_instance)
list(APPEND DEVICE_INSTANCES device_batched_gemm_gemm_instance)
list(APPEND DEVICE_INSTANCES device_batched_gemm_add_relu_gemm_add_instance)
list(APPEND DEVICE_INSTANCES device_grouped_gemm_instance)
list(APPEND DEVICE_INSTANCES device_gemm_streamk_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_fastgelu_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_silu_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_relu_add_layernorm_instance)
list(APPEND DEVICE_INSTANCES device_grouped_gemm_fixed_nk_instance)
list(APPEND DEVICE_INSTANCES device_grouped_gemm_fastgelu_instance)
list(APPEND DEVICE_INSTANCES device_grouped_gemm_tile_loop_instance)
endif()
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance)
if(SUPPORTED_GPU_TARGETS MATCHES "gfx94")
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_wp_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_ab_scale_instance)
list(APPEND DEVICE_INSTANCES device_batched_gemm_instance)
list(APPEND DEVICE_INSTANCES device_batched_gemm_reduce_instance)
list(APPEND DEVICE_INSTANCES device_gemm_multiply_add_instance)
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]")
list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_instance)
list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_wp_instance)
list(APPEND DEVICE_INSTANCES device_gemm_ab_scale_instance)
endif()
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_b_scale_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_b_scale_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_batched_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_reduce_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_streamk_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_add_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_fwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv1d_bwd_data_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv3d_bwd_data_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_bwd_data_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_convscale_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_convinvscale_instance)
list(APPEND DEVICE_INSTANCES device_gemm_splitk_instance)
list(APPEND DEVICE_INSTANCES device_gemm_b_scale_instance)
list(APPEND DEVICE_INSTANCES device_batched_gemm_b_scale_instance)
list(APPEND DEVICE_INSTANCES device_gemm_universal_batched_instance)
list(APPEND DEVICE_INSTANCES device_gemm_universal_reduce_instance)
list(APPEND DEVICE_INSTANCES device_gemm_universal_streamk_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_multiply_instance)
list(APPEND DEVICE_INSTANCES device_gemm_reduce_instance)
list(APPEND DEVICE_INSTANCES device_gemm_bias_add_reduce_instance)
list(APPEND DEVICE_INSTANCES device_conv2d_fwd_instance)
list(APPEND DEVICE_INSTANCES device_conv2d_fwd_bias_relu_instance)
list(APPEND DEVICE_INSTANCES device_conv2d_fwd_bias_relu_add_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv1d_fwd_instance)
list(APPEND DEVICE_INSTANCES device_conv1d_bwd_data_instance)
list(APPEND DEVICE_INSTANCES device_conv3d_bwd_data_instance)
list(APPEND DEVICE_INSTANCES device_conv2d_bwd_data_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_weight_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convscale_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convinvscale_instance)
endif()
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12")
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance)
list(APPEND DEVICE_INSTANCES device_gemm_bilinear_instance)
endif()
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_fwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance)
list(APPEND DEVICE_INSTANCES device_gemm_universal_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_data_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_data_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance)
endif()
if(DL_KERNELS)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_multi_d_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance)
list(APPEND DEVICE_INSTANCES device_batched_gemm_multi_d_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_weight_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance)
endif()
set(PROFILER_LIBS utility getopt::getopt)
foreach(LIB ${DEVICE_INSTANCES})
string(REGEX REPLACE "device_(.+)_instance" "\\1" INSTANCE_NAME ${LIB})
if (INSTANCE_NAME STREQUAL "")
message(FATAL_ERROR "Unexpected kernel instance name: ${LIB}")
endif()
if("${INSTANCE_NAME}" MATCHES "${CK_PROFILER_INSTANCE_FILTER}")
list(APPEND PROFILER_LIBS ${LIB})
endif()
endforeach()
message(STATUS "ckProfiler libs: ${PROFILER_LIBS}")
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE ${PROFILER_LIBS})
rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler)

View File

@@ -5,7 +5,7 @@ find_package(SQLite3 REQUIRED)
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
--working_path ${CMAKE_CURRENT_BINARY_DIR}
--config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json
# --config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json
--list_blobs
RESULT_VARIABLE ret
)
@@ -33,7 +33,7 @@ add_custom_command(
OUTPUT ${GEMM_CODEGEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
--working_path ${CMAKE_CURRENT_BINARY_DIR}
--config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json
# --config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json
--gen_blobs
)

View File

@@ -6,7 +6,7 @@ CK Tile Engine GEMM is used to generate and run GEMM kernels with different comb
User can provide kernel configuration such as tile size, warp size, padding, pipeline, scheduler and epilogue in the config file with limited values. For reference please see `./configs/user_provided_config.json`.
The Tile engine also has a default kernel configuration for providing range of configuration parameter values, which helps users who lack kernel development experience to benchmark For reference please see in `./configs/default_config.json`
The Tile engine also has a default kernel configuration for providing range of configuration parameter values, which helps users who lack kernel development experience to benchmark. For reference please see in `./configs/default_config.json`
If user does not provide kernel configuration, the tile engine uses default kernel configuration to generate kernel instances and benchmark.

View File

@@ -6,6 +6,7 @@
#include <iostream>
#include <string>
#include <fstream>
#include <sstream>
#include <stdexcept>
#include "gemm_host_api.hpp"
@@ -39,6 +40,13 @@ struct GemmProblem
bool structured_sparsity_;
std::string to_json() const
{
std::ostringstream oss;
oss << *this;
return oss.str();
}
friend std::ostream& operator<<(std::ostream& os, const GemmProblem& problem)
{
os << "{\n"

View File

@@ -22,18 +22,18 @@ class GemmProfiler
return instance;
}
bool is_problem_record_cache(const GemmProblem& gemm_problem)
bool if_should_profile(const GemmProblem& gemm_problem)
{
if(setting_.enable_profile_cache_)
{
if(!cache_db_->check_if_record_problem(
get_rocm_version(), ck_tile::get_device_name(), gemm_problem))
{
return false;
return true;
}
else
{
auto [name, perf_result] = cache_db_->query_cache(
const auto& [name, perf_result] = cache_db_->query_cache(
get_rocm_version(), ck_tile::get_device_name(), gemm_problem);
KernelInstance kernel_instance;
kernel_instance.problem_ = gemm_problem;
@@ -41,16 +41,16 @@ class GemmProfiler
kernel_instance.perf_result_.latency_ = perf_result.latency_;
kernel_instance.perf_result_.tflops_ = perf_result.tflops_;
kernel_instance.perf_result_.bandwidth_ = perf_result.bandwidth_;
std::cout << "Skip this problem for " << gemm_problem
std::cout << "Skip this instance for " << kernel_instance
<< ", Because it has already been recorded in the cache database"
<< std::endl;
kernel_instances_.emplace_back(kernel_instance);
return true;
return false;
}
}
else
{
return false;
return true;
}
}
@@ -69,7 +69,7 @@ class GemmProfiler
gemm_problem.stride_c_ = ck_tile::get_default_stride(
gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c));
if(is_problem_record_cache(gemm_problem))
if(!if_should_profile(gemm_problem))
return;
ck_tile::HostTensor<ADataType> a_m_k(ck_tile::host_tensor_descriptor(
@@ -157,7 +157,7 @@ class GemmProfiler
gemm_problem.stride_c_);
}
for(auto callable : callables)
for(auto& callable : callables)
{
auto kernel_run_result = callable(gemm_args,
ck_tile::stream_config{nullptr,
@@ -186,7 +186,7 @@ class GemmProfiler
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
const std::tuple<std::string, float>& kernel_run_result)
{
auto [name, avg_time] = kernel_run_result;
const auto& [name, avg_time] = kernel_run_result;
KernelInstance kernel_instance{name, gemm_problem, {-1.0f, -1.0f, -1.0f}};
@@ -290,7 +290,7 @@ class GemmProfiler
return kernel_instance;
}
GemmProfiler(const GemmProfiler&) = delete;
GemmProfiler(const GemmProfiler&) = delete;
GemmProfiler& operator=(const GemmProfiler&) = delete;
private:

View File

@@ -3,7 +3,6 @@
#include <sqlite3.h>
#include <tuple>
#include <sstream>
#include "benchmark_gemm.hpp"
@@ -72,9 +71,9 @@ class ProfileCacheDB
try
{
exec_direct("PRAGMA journal_mode = WAL");
exec_direct("PRAGMA synchronous = NORMAL");
exec_direct("PRAGMA foreign_keys = ON");
execute("PRAGMA journal_mode = WAL");
execute("PRAGMA synchronous = NORMAL");
execute("PRAGMA foreign_keys = ON");
constexpr const char* schema = R"sql(
CREATE TABLE IF NOT EXISTS gemm (
@@ -87,12 +86,9 @@ class ProfileCacheDB
tflops REAL CHECK(tflops > 0),
bandwidth REAL CHECK(bandwidth > 0)
);
CREATE INDEX IF NOT EXISTS idx_latency ON gemm(latency);
CREATE INDEX IF NOT EXISTS idx_tflops_desc ON gemm(tflops DESC);
CREATE INDEX IF NOT EXISTS idx_bandwidth_desc ON gemm(bandwidth DESC);
)sql";
exec_direct(schema);
execute(schema);
}
catch(...)
{
@@ -105,12 +101,12 @@ class ProfileCacheDB
const GemmProblem& gemm_problem)
{
constexpr const char* sql = R"sql(
SELECT 1 FROM gemm
WHERE rocm_version=?
AND device_name=?
AND problem=?
LIMIT 1
)sql";
SELECT 1 FROM gemm
WHERE rocm_version=?
AND device_name=?
AND problem=?
LIMIT 1
)sql";
StmtWrapper stmt(db_ptr_.get(), sql);
sqlite3_stmt* raw_stmt = stmt;
@@ -120,23 +116,15 @@ class ProfileCacheDB
db_ptr_.get());
CHECK_SQLITE3(sqlite3_bind_text(raw_stmt, idx++, device_name.c_str(), -1, SQLITE_TRANSIENT),
db_ptr_.get());
std::ostringstream oss;
oss << gemm_problem;
auto problem_json = oss.str();
CHECK_SQLITE3(
sqlite3_bind_text(raw_stmt, idx++, problem_json.c_str(), -1, SQLITE_TRANSIENT),
db_ptr_.get());
CHECK_SQLITE3(sqlite3_bind_text(
raw_stmt, idx++, gemm_problem.to_json().c_str(), -1, SQLITE_TRANSIENT),
db_ptr_.get());
int rc;
CHECK_SQLITE3_RC(sqlite3_step(raw_stmt), db_ptr_.get(), rc);
CHECK_SQLITE3(sqlite3_reset(raw_stmt), db_ptr_.get());
CHECK_SQLITE3(sqlite3_clear_bindings(raw_stmt), db_ptr_.get());
std::cout << "Query params:\n"
<< "rocm_version: " << rocm_version << "\n"
<< "device_name: " << device_name << "\n"
<< "problem_json: " << problem_json << std::endl;
if(rc == SQLITE_DONE)
{
std::cout << "No matching records found" << std::endl;
@@ -149,8 +137,9 @@ class ProfileCacheDB
const GemmProblem& gemm_problem)
{
constexpr const char* sql = R"sql(
SELECT latency, tflops, bandwidth FROM gemm
WHERE rocm_version=? AND device_name=?
SELECT instance_name, latency, tflops, bandwidth FROM gemm
WHERE rocm_version=?
AND device_name=?
AND problem=?
LIMIT 1
)sql";
@@ -164,12 +153,9 @@ class ProfileCacheDB
db_ptr_.get());
CHECK_SQLITE3(sqlite3_bind_text(raw_stmt, idx++, device_name.c_str(), -1, SQLITE_TRANSIENT),
db_ptr_.get());
std::ostringstream oss;
oss << gemm_problem;
auto problem_json = oss.str();
CHECK_SQLITE3(sqlite3_bind_text(
stmt, idx++, problem_json.c_str(), problem_json.size(), SQLITE_TRANSIENT),
db_ptr_.get());
CHECK_SQLITE3(
sqlite3_bind_text(stmt, idx++, gemm_problem.to_json().c_str(), -1, SQLITE_TRANSIENT),
db_ptr_.get());
int rc;
CHECK_SQLITE3_RC(sqlite3_step(raw_stmt), db_ptr_.get(), rc);
@@ -195,15 +181,25 @@ class ProfileCacheDB
const std::string& device_name,
const std::vector<KernelInstance>& kernen_instnaces)
{
exec_direct("BEGIN TRANSACTION");
execute("BEGIN TRANSACTION");
try
{
constexpr const char* sql = R"sql(
INSERT INTO gemm
(rocm_version, device_name,
problem, instance_name,
latency, tflops, bandwidth)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)
(rocm_version,
device_name,
problem,
instance_name,
latency,
tflops,
bandwidth)
VALUES (?1,
?2,
?3,
?4,
?5,
?6,
?7)
)sql";
StmtWrapper stmt(db_ptr_.get(), sql);
@@ -218,15 +214,10 @@ class ProfileCacheDB
CHECK_SQLITE3(
sqlite3_bind_text(raw_stmt, idx++, device_name.c_str(), -1, SQLITE_TRANSIENT),
db_ptr_.get());
std::ostringstream oss;
oss << item.problem_;
auto problem_json = oss.str();
CHECK_SQLITE3(sqlite3_bind_text(raw_stmt,
idx++,
problem_json.c_str(),
problem_json.size(),
SQLITE_TRANSIENT),
db_ptr_.get());
CHECK_SQLITE3(
sqlite3_bind_text(
raw_stmt, idx++, item.problem_.to_json().c_str(), -1, SQLITE_TRANSIENT),
db_ptr_.get());
CHECK_SQLITE3(
sqlite3_bind_text(raw_stmt, idx++, item.name_.c_str(), -1, SQLITE_TRANSIENT),
db_ptr_.get());
@@ -242,17 +233,17 @@ class ProfileCacheDB
CHECK_SQLITE3(sqlite3_reset(raw_stmt), db_ptr_.get());
CHECK_SQLITE3(sqlite3_clear_bindings(raw_stmt), db_ptr_.get());
}
exec_direct("COMMIT");
execute("COMMIT");
}
catch(...)
{
exec_direct("ROLLBACK");
execute("ROLLBACK");
throw;
}
}
private:
void exec_direct(const char* sql)
void execute(const char* sql)
{
CHECK_SQLITE3(sqlite3_exec(db_ptr_.get(), sql, nullptr, nullptr, nullptr), db_ptr_.get());
}