mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 14:54:47 +00:00
Merge commit '9fcc1ee9fd9730efd865f530afde505f2556954d' into develop
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -117,7 +117,7 @@ struct naive_attention_fwd_kernel
|
||||
std::is_same_v<KType, fp8_t> && std::is_same_v<VType, fp8_t>;
|
||||
|
||||
static constexpr int v_per_token_quant_group_size = 64;
|
||||
|
||||
static constexpr int kBlockSize = 256;
|
||||
// TODO: hardcode
|
||||
using SoftmaxType = float; // always using float to do softmax compute
|
||||
using QuantComputeType = float; // used for quant/dequant scale compute
|
||||
@@ -254,7 +254,7 @@ struct naive_attention_fwd_kernel
|
||||
__device__ T load(int i_s, int i_h, int i_d) { return base_ptr[get_offset(i_s, i_h, i_d)]; }
|
||||
};
|
||||
|
||||
__device__ __host__ static constexpr int get_block_size() { return 256; }
|
||||
__device__ __host__ static constexpr int get_block_size() { return kBlockSize; }
|
||||
|
||||
// for simpliciy, 1 WG always compute 1 token along q, compute all token along kv
|
||||
// compute all hdim from q, compute WG_SIZE hdim from v
|
||||
|
||||
Reference in New Issue
Block a user