Support Wave32 in CK_TILE - Part 1 (#2594)

* Support wave32/wave64 in CK_TILE - Part 1

* remove blocksize in kernel launch

* fix build error

* fix clang format

* fix clang format 2

* fix clang format 3

* fix fmha build error

* fix fmha build 2

* fix fmha build 3

* fix build error 4

* address review comment

* update change log

* replace KernelBlockSize with kBlockSize

* fix CI fail

* fix clang format

* address review comment and rebase code.

* fix universal test fail

---------

Co-authored-by: Lin, Qun <Quentin.Lin+amdeng@amd.com>
Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
linqunAMD
2025-08-19 01:08:31 +08:00
committed by GitHub
parent 26d3300930
commit 9fcc1ee9fd
113 changed files with 610 additions and 531 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -117,7 +117,7 @@ struct naive_attention_fwd_kernel
std::is_same_v<KType, fp8_t> && std::is_same_v<VType, fp8_t>;
static constexpr int v_per_token_quant_group_size = 64;
static constexpr int kBlockSize = 256;
// TODO: hardcode
using SoftmaxType = float; // always using float to do softmax compute
using QuantComputeType = float; // used for quant/dequant scale compute
@@ -254,7 +254,7 @@ struct naive_attention_fwd_kernel
__device__ T load(int i_s, int i_h, int i_d) { return base_ptr[get_offset(i_s, i_h, i_d)]; }
};
__device__ __host__ static constexpr int get_block_size() { return 256; }
__device__ __host__ static constexpr int get_block_size() { return kBlockSize; }
// for simpliciy, 1 WG always compute 1 token along q, compute all token along kv
// compute all hdim from q, compute WG_SIZE hdim from v