mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-27 08:25:46 +00:00
bug fix, clang format;
This commit is contained in:
@@ -21,7 +21,7 @@
|
||||
|
||||
struct AlphaBetaAdd
|
||||
{
|
||||
AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){};
|
||||
AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta) {};
|
||||
|
||||
template <typename E, typename C, typename D>
|
||||
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
|
||||
|
||||
@@ -21,7 +21,7 @@
|
||||
|
||||
struct AlphaBetaAdd
|
||||
{
|
||||
AlphaBetaAdd(int alpha, int beta) : alpha_(alpha), beta_(beta){};
|
||||
AlphaBetaAdd(int alpha, int beta) : alpha_(alpha), beta_(beta) {};
|
||||
|
||||
template <typename E, typename C, typename D>
|
||||
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
|
||||
|
||||
@@ -20,7 +20,7 @@
|
||||
|
||||
struct AlphaBetaAdd
|
||||
{
|
||||
AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){};
|
||||
AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta) {};
|
||||
|
||||
template <typename E, typename C, typename D>
|
||||
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
|
||||
|
||||
@@ -83,7 +83,7 @@ struct AddScale
|
||||
|
||||
struct AlphaBetaAdd
|
||||
{
|
||||
AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){};
|
||||
AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta) {};
|
||||
|
||||
template <typename E, typename C, typename D>
|
||||
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
|
||||
|
||||
@@ -42,7 +42,7 @@ static constexpr ck::index_t NumDimK = 2;
|
||||
|
||||
struct AlphaBetaAdd
|
||||
{
|
||||
AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){};
|
||||
AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta) {};
|
||||
|
||||
template <typename E, typename C, typename D>
|
||||
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
|
||||
|
||||
@@ -29,7 +29,7 @@ set(FMHA_FWD_CODE_GEN_COMMON_ARGS
|
||||
${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api ${FMHA_FWD_APIS}
|
||||
--optdim 32,64,128,256
|
||||
# --filter fmha_fwd...
|
||||
--filter fmha_fwd_d128_bf16_batch_b128x64x32x128x16x128_r4x1x1_r4x1x1_w32x32x16_w32x32x16_qr_async_trload_vr_npad_nlogits_nbias_nmask_nlse_ndropout_nskip_nsquant_trload
|
||||
)
|
||||
set(FMHA_BWD_CODE_GEN_COMMON_ARGS
|
||||
${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
|
||||
@@ -644,6 +644,8 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
# non qr_async_trload only support km0=128 tile size when hdim is not 128
|
||||
# non qr_async only support kn0=128 tile size when hdim is 128
|
||||
continue
|
||||
if pipeline.tag == 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 == 128) or ((hdim, hdim_v) not in [(64, 64), (128, 128)])):
|
||||
continue
|
||||
# logits_soft_cap is only allowed if no bias
|
||||
if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'):
|
||||
continue
|
||||
|
||||
Reference in New Issue
Block a user