mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 13:17:36 +00:00
merge upstream
This commit is contained in:
@@ -18,7 +18,10 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
|
||||
* Added support for Split K for grouped convolution backward data.
|
||||
* Added logit soft-capping support for fMHA forward kernels.
|
||||
* Added benchmarking support for tile engine GEMM.
|
||||
* Added profiling cache support for tile engine GEMM.
|
||||
|
||||
### Optimized
|
||||
|
||||
|
||||
@@ -58,7 +58,8 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
{F_lse},
|
||||
{F_dropout},
|
||||
{F_squant},
|
||||
{F_occupancy}>;
|
||||
{F_occupancy},
|
||||
{F_skip}>;
|
||||
|
||||
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
|
||||
|
||||
@@ -94,7 +95,7 @@ using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaFwdKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
@@ -129,9 +130,9 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
|
||||
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
|
||||
return fmha_fwd_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
@@ -160,11 +161,12 @@ class FmhaFwdApiTrait:
|
||||
skpad : str
|
||||
dpad : str
|
||||
dvpad : str
|
||||
skip : str
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\
|
||||
f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'
|
||||
f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}'
|
||||
|
||||
@property
|
||||
def scheck(self) -> str:
|
||||
@@ -227,6 +229,7 @@ class FmhaFwdPipeline:
|
||||
F_dropout : str #
|
||||
F_squant : str #
|
||||
F_mask : str # value from MASK_MAP
|
||||
F_skip : str # true/false
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -262,8 +265,12 @@ class FmhaFwdPipeline:
|
||||
if self.F_dropout == 't' : n += '_dropout'
|
||||
else: n += '_ndropout'
|
||||
|
||||
if self.F_skip == 't' : n += '_skip'
|
||||
else: n += '_nskip'
|
||||
|
||||
if self.F_squant == 't' : n += '_squant'
|
||||
else: n += '_nsquant'
|
||||
|
||||
return n
|
||||
|
||||
class FmhaFwdApiPool:
|
||||
@@ -293,7 +300,7 @@ class FmhaFwdApiPool:
|
||||
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
|
||||
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] ,
|
||||
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip],
|
||||
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
|
||||
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
|
||||
@@ -381,6 +388,7 @@ class FmhaFwdKernel:
|
||||
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
|
||||
F_dropout = BOOL_MAP[self.F_pipeline.F_dropout],
|
||||
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
|
||||
F_skip = BOOL_MAP[self.F_pipeline.F_skip],
|
||||
F_occupancy = self.F_tile.F_occupancy,
|
||||
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
|
||||
@@ -419,7 +427,8 @@ class FmhaFwdKernel:
|
||||
spad=self.F_pipeline.F_spad,
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad)
|
||||
dvpad=self.F_pipeline.F_dvpad,
|
||||
skip=self.F_pipeline.F_skip)
|
||||
|
||||
# TODO: design a more practical way to do it
|
||||
# these are the currently supported tile sizes per hdim
|
||||
@@ -453,36 +462,36 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
squant = 't' if dtype == 'fp8' else 'f'
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for logits, mask, bias, lse, dropout in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]):
|
||||
if hdim == 256:
|
||||
# if True:
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
|
||||
# the two pipelines below are used for hdim vectorized loads
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
|
||||
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
else:
|
||||
if bias == "bias":
|
||||
# TODO: rocm 6.2 compiler problem if using qr_async for bias case
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
else:
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
if receipt == 1 and bias != "bias":
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) # TODO: cover arbitraty hdim
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
# lse/dropout kernels are not needed
|
||||
for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f'))
|
||||
elif dtype in ['fp8fp16', 'fp8bf16']:
|
||||
# TODO
|
||||
None
|
||||
@@ -532,6 +541,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_bias in ['no', 'alibi']
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
cond &= pipeline.F_skip == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
@@ -540,6 +550,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_bias in ['no', 'bias']
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
cond &= pipeline.F_skip == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_fwd) integration
|
||||
@@ -565,6 +576,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
|
||||
@@ -169,6 +169,7 @@ struct fmha_fwd_args
|
||||
ck_tile::index_t window_size_left;
|
||||
ck_tile::index_t window_size_right;
|
||||
ck_tile::index_t mask_type;
|
||||
ck_tile::index_t min_seqlen_q;
|
||||
|
||||
float p_drop;
|
||||
bool s_randval;
|
||||
@@ -433,6 +434,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.min_seqlen_q,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
@@ -837,7 +839,8 @@ template <ck_tile::index_t HDim_,
|
||||
bool kPadS_,
|
||||
bool kPadSK_,
|
||||
bool kPadD_,
|
||||
bool kPadDv_>
|
||||
bool kPadDv_,
|
||||
bool kSkipMinSeqlenQ_ = false>
|
||||
struct fmha_fwd_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
@@ -861,6 +864,7 @@ struct fmha_fwd_traits_
|
||||
static constexpr bool kPadSK = kPadSK_;
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
@@ -995,6 +999,7 @@ struct fmha_fwd_traits
|
||||
bool has_lse;
|
||||
bool has_dropout;
|
||||
bool do_fp8_static_quant;
|
||||
bool skip_min_seqlen_q = false;
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -393,8 +393,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
{
|
||||
const index_t GemmM = K;
|
||||
const index_t GemmN = C * X;
|
||||
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
|
||||
const auto PadGemmM =
|
||||
GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN =
|
||||
GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
wei_grid_desc,
|
||||
@@ -432,8 +434,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
{
|
||||
const index_t GemmM = K;
|
||||
const index_t GemmN = C * X * Y;
|
||||
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
|
||||
const auto PadGemmM =
|
||||
GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN =
|
||||
GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
wei_grid_desc,
|
||||
@@ -472,8 +476,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
{
|
||||
const index_t GemmM = K;
|
||||
const index_t GemmN = C * X * Y * Z;
|
||||
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
|
||||
const auto PadGemmM =
|
||||
GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN =
|
||||
GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
wei_grid_desc,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -208,8 +208,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
|
||||
const index_t GemmM = K;
|
||||
const index_t GemmN = C * Z * X * Y;
|
||||
|
||||
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
|
||||
const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
|
||||
|
||||
const index_t GemmK0 =
|
||||
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock) * K0PerBlock;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -166,8 +166,8 @@ struct TransformConvBwdWeightToGemm
|
||||
const index_t GemmM = K;
|
||||
const index_t GemmN = C * X;
|
||||
|
||||
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
|
||||
const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
|
||||
|
||||
const index_t GemmKBatch = batch_k;
|
||||
const index_t GemmK0 =
|
||||
@@ -365,8 +365,8 @@ struct TransformConvBwdWeightToGemm
|
||||
const index_t GemmM = K;
|
||||
const index_t GemmN = C * X * Y;
|
||||
|
||||
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
|
||||
const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
|
||||
|
||||
const index_t GemmKBatch = batch_k;
|
||||
const index_t GemmK0 =
|
||||
@@ -558,8 +558,8 @@ struct TransformConvBwdWeightToGemm
|
||||
const index_t GemmM = K;
|
||||
const index_t GemmN = C * Z * X * Y;
|
||||
|
||||
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
|
||||
const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
|
||||
|
||||
const index_t GemmKBatch = batch_k;
|
||||
const index_t GemmK0 =
|
||||
|
||||
@@ -346,8 +346,8 @@ struct TransformConvBwdWeightToGemmV2
|
||||
const index_t GemmM = K * NumGroupsToMerge;
|
||||
const index_t GemmN = C * X * NumGroupsToMerge;
|
||||
|
||||
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
|
||||
const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
|
||||
|
||||
const index_t GemmKBatch = batch_k;
|
||||
const index_t GemmK0 =
|
||||
@@ -534,8 +534,8 @@ struct TransformConvBwdWeightToGemmV2
|
||||
const index_t GemmM = K * NumGroupsToMerge;
|
||||
const index_t GemmN = C * X * Y * NumGroupsToMerge;
|
||||
|
||||
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
|
||||
const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
|
||||
|
||||
const index_t GemmKBatch = batch_k;
|
||||
const index_t GemmK0 =
|
||||
@@ -737,8 +737,8 @@ struct TransformConvBwdWeightToGemmV2
|
||||
const index_t GemmM = K * NumGroupsToMerge;
|
||||
const index_t GemmN = C * Z * X * Y * NumGroupsToMerge;
|
||||
|
||||
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
|
||||
const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
|
||||
|
||||
const index_t GemmKBatch = batch_k;
|
||||
const index_t GemmK0 =
|
||||
|
||||
@@ -55,8 +55,8 @@
|
||||
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
|
||||
#include "ck_tile/core/tensor/tile_elementwise.hpp"
|
||||
#include "ck_tile/core/tensor/tile_scatter_gather.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window_base.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window_base.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window_linear.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window_utils.hpp"
|
||||
#include "ck_tile/core/tensor/transpose_tile.hpp"
|
||||
|
||||
@@ -35,4 +35,5 @@
|
||||
#include "ck_tile/host/reference/reference_softmax.hpp"
|
||||
#include "ck_tile/host/reference/reference_topk.hpp"
|
||||
#include "ck_tile/host/stream_config.hpp"
|
||||
#include "ck_tile/host/stream_utils.hpp"
|
||||
#include "ck_tile/host/timer.hpp"
|
||||
|
||||
@@ -53,6 +53,8 @@ struct FmhaFwdKernel
|
||||
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
|
||||
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
|
||||
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
|
||||
static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ;
|
||||
|
||||
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
|
||||
static constexpr bool kHasMask = FmhaMask::IsMasking;
|
||||
@@ -257,6 +259,11 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t batch_stride_randval = 0;
|
||||
};
|
||||
|
||||
struct FmhaFwdSkipMinSeqlenQKargs
|
||||
{
|
||||
ck_tile::index_t min_seqlen_q = 0;
|
||||
};
|
||||
|
||||
struct FmhaFwdBatchModeKargs
|
||||
: FmhaFwdCommonKargs,
|
||||
std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
|
||||
@@ -287,7 +294,8 @@ struct FmhaFwdKernel
|
||||
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
|
||||
std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
|
||||
std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
|
||||
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
|
||||
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>,
|
||||
std::conditional_t<kSkipMinSeqlenQ, FmhaFwdSkipMinSeqlenQKargs, FmhaFwdEmptyKargs<6>>
|
||||
{
|
||||
const int32_t* seqstart_q_ptr;
|
||||
const int32_t* seqstart_k_ptr;
|
||||
@@ -664,6 +672,7 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
ck_tile::index_t min_seqlen_q,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
@@ -698,6 +707,7 @@ struct FmhaFwdKernel
|
||||
{}, // placeholder for fp8_static_quant args
|
||||
{}, // placeholder for dropout
|
||||
{}, // placeholder for logits_soft_cap
|
||||
{}, // placeholder for min_seqlen_q
|
||||
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
|
||||
@@ -753,6 +763,10 @@ struct FmhaFwdKernel
|
||||
{
|
||||
kargs.init_logits_soft_cap(logits_soft_cap);
|
||||
}
|
||||
if constexpr(kSkipMinSeqlenQ)
|
||||
{
|
||||
kargs.min_seqlen_q = min_seqlen_q;
|
||||
}
|
||||
|
||||
return kargs;
|
||||
}
|
||||
@@ -969,7 +983,15 @@ struct FmhaFwdKernel
|
||||
|
||||
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
|
||||
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
if constexpr(kHasMask)
|
||||
{
|
||||
// assume that num_tile_n1 is always 1
|
||||
return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -989,7 +1011,15 @@ struct FmhaFwdKernel
|
||||
|
||||
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
|
||||
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
if constexpr(kHasMask)
|
||||
{
|
||||
// assume that num_tile_n1 is always 1
|
||||
return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1053,6 +1083,14 @@ struct FmhaFwdKernel
|
||||
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
|
||||
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
|
||||
|
||||
if constexpr(kSkipMinSeqlenQ)
|
||||
{
|
||||
if(kargs.seqlen_q <= kargs.min_seqlen_q)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// # of required blocks is different in each group, terminate unnecessary blocks
|
||||
// earlier
|
||||
if(kargs.seqlen_q <= i_m0)
|
||||
|
||||
@@ -561,7 +561,16 @@ struct FmhaFwdSplitKVKernel
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
|
||||
if constexpr(kHasMask)
|
||||
{
|
||||
// assume that num_tile_n1 is always 1
|
||||
return ck_tile::make_tuple(
|
||||
(gridDim.x / kargs.num_splits) - 1 - i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
|
||||
}
|
||||
}
|
||||
|
||||
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
|
||||
@@ -53,6 +53,7 @@ struct BlockFmhaPipelineProblem
|
||||
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
|
||||
static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap;
|
||||
static constexpr bool kSkipMinSeqlenQ = Traits::kSkipMinSeqlenQ;
|
||||
static constexpr auto BiasEnum = Traits::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Traits::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Traits::kHasDropout;
|
||||
|
||||
@@ -19,7 +19,8 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kStoreLSE_,
|
||||
bool kHasDropout_,
|
||||
bool kDoFp8StaticQuant_,
|
||||
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
|
||||
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
|
||||
bool kSkipMinSeqlenQ_ = false /* skip min seqlen q while chunked prefill */>
|
||||
struct TileFmhaTraits
|
||||
{
|
||||
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
|
||||
@@ -33,6 +34,7 @@ struct TileFmhaTraits
|
||||
static constexpr bool kHasDropout = kHasDropout_;
|
||||
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
|
||||
};
|
||||
|
||||
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
|
||||
@@ -1,6 +1,16 @@
|
||||
# ckProfiler
|
||||
set(PROFILER_SOURCES
|
||||
profiler.cpp
|
||||
set(CK_PROFILER_OP_FILTER "" CACHE STRING "Filter for the operators to be profiled. Default is to include all")
|
||||
set(CK_PROFILER_INSTANCE_FILTER "" CACHE STRING "Filter for the kernels instances to be profiled. Default is to be the same as the operator filter")
|
||||
if (CK_PROFILER_OP_FILTER STREQUAL "")
|
||||
set(CK_PROFILER_OP_FILTER ".+")
|
||||
endif()
|
||||
if (CK_PROFILER_INSTANCE_FILTER STREQUAL "")
|
||||
set(CK_PROFILER_INSTANCE_FILTER ${CK_PROFILER_OP_FILTER})
|
||||
endif()
|
||||
message(STATUS "CK_PROFILER_OP_FILTER: ${CK_PROFILER_OP_FILTER}")
|
||||
message(STATUS "CK_PROFILER_INSTANCE_FILTER: ${CK_PROFILER_INSTANCE_FILTER}")
|
||||
|
||||
set(PROFILER_OPS
|
||||
profile_gemm.cpp
|
||||
profile_reduce.cpp
|
||||
profile_groupnorm_bwd_data.cpp
|
||||
@@ -26,161 +36,188 @@ set(PROFILER_SOURCES
|
||||
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
|
||||
list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp)
|
||||
list(APPEND PROFILER_OPS profile_contraction_bilinear.cpp)
|
||||
list(APPEND PROFILER_OPS profile_contraction_scale.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_reduce.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_batched_gemm_gemm.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_add.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_add_add_fastgelu.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_add_fastgelu.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_grouped_gemm.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_streamk.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_fastgelu.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_add_relu.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_add_silu.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_grouped_gemm_fixed_nk.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_grouped_gemm_tile_loop.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_grouped_gemm_multiply_tile_loop.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_reduce.cpp)
|
||||
list(APPEND PROFILER_OPS profile_batched_gemm_gemm.cpp)
|
||||
list(APPEND PROFILER_OPS profile_batched_gemm_add_relu_gemm_add.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_add.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_add_add_fastgelu.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_add_fastgelu.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_gemm.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_streamk.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_fastgelu.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_add_relu.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_add_silu.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_add_relu_add_layernorm.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_gemm_fixed_nk.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_gemm_fastgelu.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_gemm_tile_loop.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_gemm_multiply_tile_loop.cpp)
|
||||
endif()
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp)
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx94")
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply_wp.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_ab_scale.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_multiply_add.cpp)
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]")
|
||||
list(APPEND PROFILER_OPS profile_gemm_multiply_multiply.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_multiply_multiply_wp.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_ab_scale.cpp)
|
||||
endif()
|
||||
list(APPEND PROFILER_SOURCES profile_batched_gemm.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_batched_gemm_reduce.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_add_multiply.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_bias_add_reduce.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_splitk.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_b_scale.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_batched_gemm_b_scale.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_universal_batched.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_universal_reduce.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_universal_streamk.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu_add.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_conv_bwd_data.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_conv_fwd.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_grouped_conv_fwd_outelementop.cpp)
|
||||
list(APPEND PROFILER_OPS profile_batched_gemm.cpp)
|
||||
list(APPEND PROFILER_OPS profile_batched_gemm_reduce.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_add_multiply.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_bias_add_reduce.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_splitk.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_b_scale.cpp)
|
||||
list(APPEND PROFILER_OPS profile_batched_gemm_b_scale.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_universal_batched.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_universal_reduce.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_universal_streamk.cpp)
|
||||
list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu.cpp)
|
||||
list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu_add.cpp)
|
||||
list(APPEND PROFILER_OPS profile_conv_bwd_data.cpp)
|
||||
list(APPEND PROFILER_OPS profile_conv_fwd.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_outelementop.cpp)
|
||||
|
||||
endif()
|
||||
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12" OR SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_bilinear.cpp)
|
||||
endif()
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_universal.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_grouped_conv_fwd.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_data.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_weight.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_universal.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_fwd.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_bwd_data.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp)
|
||||
endif()
|
||||
|
||||
if(DL_KERNELS)
|
||||
list(APPEND PROFILER_SOURCES profile_batched_gemm_multi_d.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_weight.cpp)
|
||||
list(APPEND PROFILER_OPS profile_batched_gemm_multi_d.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp)
|
||||
endif()
|
||||
|
||||
set(PROFILER_SOURCES profiler.cpp)
|
||||
foreach(SOURCE ${PROFILER_OPS})
|
||||
string(REGEX REPLACE "profile_(.+)\.cpp" "\\1" OP_NAME ${SOURCE})
|
||||
if (OP_NAME STREQUAL "")
|
||||
message(FATAL_ERROR "Unexpected source file name: ${SOURCE}")
|
||||
endif()
|
||||
if("${OP_NAME}" MATCHES "${CK_PROFILER_OP_FILTER}")
|
||||
list(APPEND PROFILER_SOURCES ${SOURCE})
|
||||
endif()
|
||||
endforeach()
|
||||
message(STATUS "ckProfiler sources: ${PROFILER_SOURCES}")
|
||||
|
||||
set(PROFILER_EXECUTABLE ckProfiler)
|
||||
|
||||
add_executable(${PROFILER_EXECUTABLE} ${PROFILER_SOURCES})
|
||||
target_compile_options(${PROFILER_EXECUTABLE} PRIVATE -Wno-global-constructors)
|
||||
# flags to compress the library
|
||||
if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 600241132)
|
||||
message("Adding --offload-compress flag for ${PROFILER_EXECUTABLE}")
|
||||
message(STATUS "Adding --offload-compress flag for ${PROFILER_EXECUTABLE}")
|
||||
target_compile_options(${PROFILER_EXECUTABLE} PRIVATE --offload-compress)
|
||||
endif()
|
||||
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE utility getopt::getopt)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_fwd_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_data_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_gamma_beta_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool2d_fwd_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool3d_fwd_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool2d_bwd_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool3d_bwd_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_max_pool_bwd_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_image_to_column_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_transpose_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_permute_scale_instance)
|
||||
|
||||
set(DEVICE_INSTANCES "")
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_normalization_fwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_normalization_bwd_data_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_normalization_bwd_gamma_beta_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_softmax_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_reduce_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_batchnorm_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_pool2d_fwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_pool3d_fwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_avg_pool2d_bwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_avg_pool3d_bwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_max_pool_bwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_image_to_column_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_column_to_image_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_transpose_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_permute_scale_instance)
|
||||
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_contraction_bilinear_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_contraction_scale_instance)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgelu_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_fastgelu_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_streamk_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_silu_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fixed_nk_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_tile_loop_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_add_fastgelu_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_fastgelu_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_batched_gemm_gemm_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_batched_gemm_add_relu_gemm_add_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_gemm_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_streamk_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_fastgelu_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_silu_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_relu_add_layernorm_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_gemm_fixed_nk_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_gemm_fastgelu_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_gemm_tile_loop_instance)
|
||||
endif()
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance)
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx94")
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_wp_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_ab_scale_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_batched_gemm_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_batched_gemm_reduce_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_multiply_add_instance)
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]")
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_wp_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_ab_scale_instance)
|
||||
endif()
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_b_scale_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_b_scale_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_batched_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_reduce_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_streamk_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_add_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_fwd_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv1d_bwd_data_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv3d_bwd_data_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_bwd_data_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_convscale_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_convinvscale_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_splitk_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_b_scale_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_batched_gemm_b_scale_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_universal_batched_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_universal_reduce_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_universal_streamk_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_multiply_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_reduce_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_bias_add_reduce_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_conv2d_fwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_conv2d_fwd_bias_relu_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_conv2d_fwd_bias_relu_add_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv1d_fwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_conv1d_bwd_data_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_conv3d_bwd_data_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_conv2d_bwd_data_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_weight_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convscale_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convinvscale_instance)
|
||||
endif()
|
||||
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12")
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_bilinear_instance)
|
||||
endif()
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_fwd_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_universal_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_data_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_data_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance)
|
||||
endif()
|
||||
|
||||
if(DL_KERNELS)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_multi_d_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_batched_gemm_multi_d_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_weight_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance)
|
||||
endif()
|
||||
|
||||
set(PROFILER_LIBS utility getopt::getopt)
|
||||
foreach(LIB ${DEVICE_INSTANCES})
|
||||
string(REGEX REPLACE "device_(.+)_instance" "\\1" INSTANCE_NAME ${LIB})
|
||||
if (INSTANCE_NAME STREQUAL "")
|
||||
message(FATAL_ERROR "Unexpected kernel instance name: ${LIB}")
|
||||
endif()
|
||||
if("${INSTANCE_NAME}" MATCHES "${CK_PROFILER_INSTANCE_FILTER}")
|
||||
list(APPEND PROFILER_LIBS ${LIB})
|
||||
endif()
|
||||
endforeach()
|
||||
message(STATUS "ckProfiler libs: ${PROFILER_LIBS}")
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE ${PROFILER_LIBS})
|
||||
|
||||
rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler)
|
||||
|
||||
@@ -5,7 +5,7 @@ find_package(SQLite3 REQUIRED)
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
|
||||
--working_path ${CMAKE_CURRENT_BINARY_DIR}
|
||||
--config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json
|
||||
# --config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json
|
||||
--list_blobs
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
@@ -33,7 +33,7 @@ add_custom_command(
|
||||
OUTPUT ${GEMM_CODEGEN_BLOBS}
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
|
||||
--working_path ${CMAKE_CURRENT_BINARY_DIR}
|
||||
--config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json
|
||||
# --config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json
|
||||
--gen_blobs
|
||||
)
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ CK Tile Engine GEMM is used to generate and run GEMM kernels with different comb
|
||||
|
||||
Users can provide a kernel configuration (tile size, warp size, padding, pipeline, scheduler and epilogue) in the config file; each parameter accepts only a limited set of values. For reference, please see `./configs/user_provided_config.json`.
|
||||
|
||||
The Tile engine also has a default kernel configuration for providing range of configuration parameter values, which helps users who lack kernel development experience to benchmark For reference please see in `./configs/default_config.json`
|
||||
The tile engine also has a default kernel configuration that provides a range of values for each configuration parameter, which helps users who lack kernel development experience to run benchmarks. For reference, please see `./configs/default_config.json`.
|
||||
|
||||
If the user does not provide a kernel configuration, the tile engine uses the default kernel configuration to generate kernel instances and benchmark them.
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
|
||||
#include "gemm_host_api.hpp"
|
||||
@@ -39,6 +40,13 @@ struct GemmProblem
|
||||
|
||||
bool structured_sparsity_;
|
||||
|
||||
std::string to_json() const
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << *this;
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const GemmProblem& problem)
|
||||
{
|
||||
os << "{\n"
|
||||
|
||||
@@ -22,18 +22,18 @@ class GemmProfiler
|
||||
return instance;
|
||||
}
|
||||
|
||||
bool is_problem_record_cache(const GemmProblem& gemm_problem)
|
||||
bool if_should_profile(const GemmProblem& gemm_problem)
|
||||
{
|
||||
if(setting_.enable_profile_cache_)
|
||||
{
|
||||
if(!cache_db_->check_if_record_problem(
|
||||
get_rocm_version(), ck_tile::get_device_name(), gemm_problem))
|
||||
{
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
else
|
||||
{
|
||||
auto [name, perf_result] = cache_db_->query_cache(
|
||||
const auto& [name, perf_result] = cache_db_->query_cache(
|
||||
get_rocm_version(), ck_tile::get_device_name(), gemm_problem);
|
||||
KernelInstance kernel_instance;
|
||||
kernel_instance.problem_ = gemm_problem;
|
||||
@@ -41,16 +41,16 @@ class GemmProfiler
|
||||
kernel_instance.perf_result_.latency_ = perf_result.latency_;
|
||||
kernel_instance.perf_result_.tflops_ = perf_result.tflops_;
|
||||
kernel_instance.perf_result_.bandwidth_ = perf_result.bandwidth_;
|
||||
std::cout << "Skip this problem for " << gemm_problem
|
||||
std::cout << "Skip this instance for " << kernel_instance
|
||||
<< ", Because it has already been recorded in the cache database"
|
||||
<< std::endl;
|
||||
kernel_instances_.emplace_back(kernel_instance);
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,7 +69,7 @@ class GemmProfiler
|
||||
gemm_problem.stride_c_ = ck_tile::get_default_stride(
|
||||
gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c));
|
||||
|
||||
if(is_problem_record_cache(gemm_problem))
|
||||
if(!if_should_profile(gemm_problem))
|
||||
return;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m_k(ck_tile::host_tensor_descriptor(
|
||||
@@ -157,7 +157,7 @@ class GemmProfiler
|
||||
gemm_problem.stride_c_);
|
||||
}
|
||||
|
||||
for(auto callable : callables)
|
||||
for(auto& callable : callables)
|
||||
{
|
||||
auto kernel_run_result = callable(gemm_args,
|
||||
ck_tile::stream_config{nullptr,
|
||||
@@ -186,7 +186,7 @@ class GemmProfiler
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
const std::tuple<std::string, float>& kernel_run_result)
|
||||
{
|
||||
auto [name, avg_time] = kernel_run_result;
|
||||
const auto& [name, avg_time] = kernel_run_result;
|
||||
|
||||
KernelInstance kernel_instance{name, gemm_problem, {-1.0f, -1.0f, -1.0f}};
|
||||
|
||||
@@ -290,7 +290,7 @@ class GemmProfiler
|
||||
return kernel_instance;
|
||||
}
|
||||
|
||||
GemmProfiler(const GemmProfiler&) = delete;
|
||||
GemmProfiler(const GemmProfiler&) = delete;
|
||||
GemmProfiler& operator=(const GemmProfiler&) = delete;
|
||||
|
||||
private:
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
|
||||
#include <sqlite3.h>
|
||||
#include <tuple>
|
||||
#include <sstream>
|
||||
|
||||
#include "benchmark_gemm.hpp"
|
||||
|
||||
@@ -72,9 +71,9 @@ class ProfileCacheDB
|
||||
|
||||
try
|
||||
{
|
||||
exec_direct("PRAGMA journal_mode = WAL");
|
||||
exec_direct("PRAGMA synchronous = NORMAL");
|
||||
exec_direct("PRAGMA foreign_keys = ON");
|
||||
execute("PRAGMA journal_mode = WAL");
|
||||
execute("PRAGMA synchronous = NORMAL");
|
||||
execute("PRAGMA foreign_keys = ON");
|
||||
|
||||
constexpr const char* schema = R"sql(
|
||||
CREATE TABLE IF NOT EXISTS gemm (
|
||||
@@ -87,12 +86,9 @@ class ProfileCacheDB
|
||||
tflops REAL CHECK(tflops > 0),
|
||||
bandwidth REAL CHECK(bandwidth > 0)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_latency ON gemm(latency);
|
||||
CREATE INDEX IF NOT EXISTS idx_tflops_desc ON gemm(tflops DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_bandwidth_desc ON gemm(bandwidth DESC);
|
||||
)sql";
|
||||
|
||||
exec_direct(schema);
|
||||
execute(schema);
|
||||
}
|
||||
catch(...)
|
||||
{
|
||||
@@ -105,12 +101,12 @@ class ProfileCacheDB
|
||||
const GemmProblem& gemm_problem)
|
||||
{
|
||||
constexpr const char* sql = R"sql(
|
||||
SELECT 1 FROM gemm
|
||||
WHERE rocm_version=?
|
||||
AND device_name=?
|
||||
AND problem=?
|
||||
LIMIT 1
|
||||
)sql";
|
||||
SELECT 1 FROM gemm
|
||||
WHERE rocm_version=?
|
||||
AND device_name=?
|
||||
AND problem=?
|
||||
LIMIT 1
|
||||
)sql";
|
||||
|
||||
StmtWrapper stmt(db_ptr_.get(), sql);
|
||||
sqlite3_stmt* raw_stmt = stmt;
|
||||
@@ -120,23 +116,15 @@ class ProfileCacheDB
|
||||
db_ptr_.get());
|
||||
CHECK_SQLITE3(sqlite3_bind_text(raw_stmt, idx++, device_name.c_str(), -1, SQLITE_TRANSIENT),
|
||||
db_ptr_.get());
|
||||
std::ostringstream oss;
|
||||
oss << gemm_problem;
|
||||
auto problem_json = oss.str();
|
||||
CHECK_SQLITE3(
|
||||
sqlite3_bind_text(raw_stmt, idx++, problem_json.c_str(), -1, SQLITE_TRANSIENT),
|
||||
db_ptr_.get());
|
||||
CHECK_SQLITE3(sqlite3_bind_text(
|
||||
raw_stmt, idx++, gemm_problem.to_json().c_str(), -1, SQLITE_TRANSIENT),
|
||||
db_ptr_.get());
|
||||
|
||||
int rc;
|
||||
CHECK_SQLITE3_RC(sqlite3_step(raw_stmt), db_ptr_.get(), rc);
|
||||
CHECK_SQLITE3(sqlite3_reset(raw_stmt), db_ptr_.get());
|
||||
CHECK_SQLITE3(sqlite3_clear_bindings(raw_stmt), db_ptr_.get());
|
||||
|
||||
std::cout << "Query params:\n"
|
||||
<< "rocm_version: " << rocm_version << "\n"
|
||||
<< "device_name: " << device_name << "\n"
|
||||
<< "problem_json: " << problem_json << std::endl;
|
||||
|
||||
if(rc == SQLITE_DONE)
|
||||
{
|
||||
std::cout << "No matching records found" << std::endl;
|
||||
@@ -149,8 +137,9 @@ class ProfileCacheDB
|
||||
const GemmProblem& gemm_problem)
|
||||
{
|
||||
constexpr const char* sql = R"sql(
|
||||
SELECT latency, tflops, bandwidth FROM gemm
|
||||
WHERE rocm_version=? AND device_name=?
|
||||
SELECT instance_name, latency, tflops, bandwidth FROM gemm
|
||||
WHERE rocm_version=?
|
||||
AND device_name=?
|
||||
AND problem=?
|
||||
LIMIT 1
|
||||
)sql";
|
||||
@@ -164,12 +153,9 @@ class ProfileCacheDB
|
||||
db_ptr_.get());
|
||||
CHECK_SQLITE3(sqlite3_bind_text(raw_stmt, idx++, device_name.c_str(), -1, SQLITE_TRANSIENT),
|
||||
db_ptr_.get());
|
||||
std::ostringstream oss;
|
||||
oss << gemm_problem;
|
||||
auto problem_json = oss.str();
|
||||
CHECK_SQLITE3(sqlite3_bind_text(
|
||||
stmt, idx++, problem_json.c_str(), problem_json.size(), SQLITE_TRANSIENT),
|
||||
db_ptr_.get());
|
||||
CHECK_SQLITE3(
|
||||
sqlite3_bind_text(stmt, idx++, gemm_problem.to_json().c_str(), -1, SQLITE_TRANSIENT),
|
||||
db_ptr_.get());
|
||||
|
||||
int rc;
|
||||
CHECK_SQLITE3_RC(sqlite3_step(raw_stmt), db_ptr_.get(), rc);
|
||||
@@ -195,15 +181,25 @@ class ProfileCacheDB
|
||||
const std::string& device_name,
|
||||
const std::vector<KernelInstance>& kernen_instnaces)
|
||||
{
|
||||
exec_direct("BEGIN TRANSACTION");
|
||||
execute("BEGIN TRANSACTION");
|
||||
try
|
||||
{
|
||||
constexpr const char* sql = R"sql(
|
||||
INSERT INTO gemm
|
||||
(rocm_version, device_name,
|
||||
problem, instance_name,
|
||||
latency, tflops, bandwidth)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)
|
||||
(rocm_version,
|
||||
device_name,
|
||||
problem,
|
||||
instance_name,
|
||||
latency,
|
||||
tflops,
|
||||
bandwidth)
|
||||
VALUES (?1,
|
||||
?2,
|
||||
?3,
|
||||
?4,
|
||||
?5,
|
||||
?6,
|
||||
?7)
|
||||
)sql";
|
||||
|
||||
StmtWrapper stmt(db_ptr_.get(), sql);
|
||||
@@ -218,15 +214,10 @@ class ProfileCacheDB
|
||||
CHECK_SQLITE3(
|
||||
sqlite3_bind_text(raw_stmt, idx++, device_name.c_str(), -1, SQLITE_TRANSIENT),
|
||||
db_ptr_.get());
|
||||
std::ostringstream oss;
|
||||
oss << item.problem_;
|
||||
auto problem_json = oss.str();
|
||||
CHECK_SQLITE3(sqlite3_bind_text(raw_stmt,
|
||||
idx++,
|
||||
problem_json.c_str(),
|
||||
problem_json.size(),
|
||||
SQLITE_TRANSIENT),
|
||||
db_ptr_.get());
|
||||
CHECK_SQLITE3(
|
||||
sqlite3_bind_text(
|
||||
raw_stmt, idx++, item.problem_.to_json().c_str(), -1, SQLITE_TRANSIENT),
|
||||
db_ptr_.get());
|
||||
CHECK_SQLITE3(
|
||||
sqlite3_bind_text(raw_stmt, idx++, item.name_.c_str(), -1, SQLITE_TRANSIENT),
|
||||
db_ptr_.get());
|
||||
@@ -242,17 +233,17 @@ class ProfileCacheDB
|
||||
CHECK_SQLITE3(sqlite3_reset(raw_stmt), db_ptr_.get());
|
||||
CHECK_SQLITE3(sqlite3_clear_bindings(raw_stmt), db_ptr_.get());
|
||||
}
|
||||
exec_direct("COMMIT");
|
||||
execute("COMMIT");
|
||||
}
|
||||
catch(...)
|
||||
{
|
||||
exec_direct("ROLLBACK");
|
||||
execute("ROLLBACK");
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
void exec_direct(const char* sql)
|
||||
void execute(const char* sql)
|
||||
{
|
||||
CHECK_SQLITE3(sqlite3_exec(db_ptr_.get(), sql, nullptr, nullptr, nullptr), db_ptr_.get());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user