mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Re-format the .hpp/.cpp files using clang-format-18
This commit is contained in:
@@ -564,9 +564,9 @@ struct HstuAttentionFwdKernel
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_block = blockIdx.z;
|
||||
#else
|
||||
const index_t i_block = blockIdx.x;
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
const index_t i_block = blockIdx.x;
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
#endif
|
||||
|
||||
const auto f = [](index_t dividend, index_t divisor) {
|
||||
@@ -591,16 +591,16 @@ struct HstuAttentionFwdKernel
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_block = blockIdx.z;
|
||||
#else
|
||||
const index_t i_block = blockIdx.x;
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
const index_t i_block = blockIdx.x;
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
#endif
|
||||
|
||||
#if HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM
|
||||
index_t i_tile_m = i_block;
|
||||
i_tile_m = gridDim.z - 1 - i_tile_m;
|
||||
#else
|
||||
const index_t i_tile_m = i_block;
|
||||
const index_t i_tile_m = i_block;
|
||||
#endif
|
||||
const index_t i_tile_n = 0;
|
||||
|
||||
|
||||
@@ -582,8 +582,8 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
|
||||
|
||||
auto warp_gemm = [&]() {
|
||||
if constexpr((std::is_same_v<typename Problem::QKVDataType, half_t> ||
|
||||
std::is_same_v<typename Problem::QKVDataType, bf16_t>)&&std::
|
||||
is_same_v<typename Problem::GemmAccDataType, float>)
|
||||
std::is_same_v<typename Problem::QKVDataType, bf16_t>) &&
|
||||
std::is_same_v<typename Problem::GemmAccDataType, float>)
|
||||
{
|
||||
constexpr index_t WarpGemmM =
|
||||
Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<0>{});
|
||||
@@ -654,8 +654,8 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
|
||||
|
||||
auto warp_gemm = [&]() {
|
||||
if constexpr((std::is_same_v<typename Problem::QKVDataType, half_t> ||
|
||||
std::is_same_v<typename Problem::QKVDataType, bf16_t>)&&std::
|
||||
is_same_v<typename Problem::GemmAccDataType, float>)
|
||||
std::is_same_v<typename Problem::QKVDataType, bf16_t>) &&
|
||||
std::is_same_v<typename Problem::GemmAccDataType, float>)
|
||||
{
|
||||
constexpr index_t WarpGemmM =
|
||||
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<0>{});
|
||||
|
||||
@@ -562,9 +562,9 @@ struct HstuAttentionFwdSplitKVKernel
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_block = blockIdx.z;
|
||||
#else
|
||||
const index_t i_block = blockIdx.x;
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
const index_t i_block = blockIdx.x;
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
#endif
|
||||
|
||||
#if HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM
|
||||
@@ -586,9 +586,9 @@ struct HstuAttentionFwdSplitKVKernel
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_block = blockIdx.z;
|
||||
#else
|
||||
const index_t i_block = blockIdx.x;
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
const index_t i_block = blockIdx.x;
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
#endif
|
||||
|
||||
#if HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM
|
||||
@@ -596,8 +596,8 @@ struct HstuAttentionFwdSplitKVKernel
|
||||
auto [i_tile_m, i_split] = f(i_tile_m_i_split, kargs.num_splits);
|
||||
i_tile_m = gridDim.z / kargs.num_splits - 1 - i_tile_m;
|
||||
#else
|
||||
index_t i_tile_m_i_split = i_block;
|
||||
auto [i_tile_m, i_split] = f(i_tile_m_i_split, kargs.num_splits);
|
||||
index_t i_tile_m_i_split = i_block;
|
||||
auto [i_tile_m, i_split] = f(i_tile_m_i_split, kargs.num_splits);
|
||||
#endif
|
||||
const index_t i_tile_n = 0;
|
||||
|
||||
|
||||
@@ -120,7 +120,7 @@ CK_TILE_DEVICE static void scale_tile_in_pack(InOutDstrTensor& in_out_dstr_tenso
|
||||
|
||||
static_for<0, thread_buffer_size, 2>{}([&](auto idx) {
|
||||
fp32x2_t input = {in_out_dstr_tensor.thread_buf_[idx],
|
||||
in_out_dstr_tensor.thread_buf_[idx + 1]};
|
||||
in_out_dstr_tensor.thread_buf_[idx + 1]};
|
||||
auto output = pk_mul_f32(input, pk_scale);
|
||||
in_out_dstr_tensor.thread_buf_[idx] = output.x;
|
||||
in_out_dstr_tensor.thread_buf_[idx + 1] = output.y;
|
||||
|
||||
Reference in New Issue
Block a user