mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
[CK] suppress compiler warnings while building PyTorch. (#7760) ## Motivation Recently added compiler flags, which are required to suppress false warnings from the latest staging compiler, are not recognized by older compiler versions and trigger an avalanche of warnings. A previous attempt to suppress them with the -Wno-unknown-warning-option flag didn't help, because that flag wasn't recognized either and just added more warnings. I've verified that the current approach (checking the clang version) actually works as intended and makes the warnings go away. ## Technical Details <!-- Explain the changes along with any relevant GitHub links. --> ## Test Plan <!-- Explain any relevant testing done to verify this PR. --> ## Test Result <!-- Briefly summarize test outcomes. --> ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
181 lines
6.1 KiB
C++
181 lines
6.1 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#pragma once
|
|
|
|
#include <iostream>
|
|
#include <string>
|
|
|
|
#include "ck_tile/core.hpp"
|
|
#include "ck_tile/ops/common.hpp"
|
|
#include "ck_tile/host/concat.hpp"
|
|
#include "ck_tile/host/kernel_launch.hpp"
|
|
#include "ck_tile/host/stream_utils.hpp"
|
|
#include "ck_tile/core/utility/env.hpp"
|
|
#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"
|
|
#include "ck_tile/core/utility/type_traits.hpp"
|
|
|
|
// Silence clang's -Wlifetime-safety-intra-tu-suggestions for the rest of this header.
// Older compilers do not recognize that flag, so it is only requested from clang 23 and newer.
// `defined(__clang__)` is checked first so that non-clang compilers never evaluate the
// undefined `__clang_major__` macro, which -Wundef would report. On every compiler this
// condition has the same value as a bare `__clang_major__ >= 23`, so it still pairs with the
// matching `pop` at the end of the header.
#if defined(__clang__) && __clang_major__ >= 23
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions"
#endif
|
|
namespace ck_tile {
|
|
|
|
/// @brief The GEMM kernel host arguments.
|
|
///
|
|
/// @par Overview
|
|
/// This structure is passed to @ref GemmKernel "GemmKernel" when creating kernel arguments
|
|
/// object. It contain all necessary information required to build proper kernel argument
|
|
/// and launch kernel on GPU.
|
|
/// This structure defines the GEMM problem configuration by stating all required information
|
|
/// like M,N,K sizes and respective strides.
|
|
struct GemmHostArgs
|
|
{
|
|
CK_TILE_HOST GemmHostArgs() = default;
|
|
CK_TILE_HOST GemmHostArgs(const void* a_ptr_,
|
|
const void* b_ptr_,
|
|
void* e_ptr_,
|
|
index_t k_batch_,
|
|
index_t M_,
|
|
index_t N_,
|
|
index_t K_,
|
|
index_t stride_A_,
|
|
index_t stride_B_,
|
|
index_t stride_E_)
|
|
: a_ptr(a_ptr_),
|
|
b_ptr(b_ptr_),
|
|
e_ptr(e_ptr_),
|
|
M(M_),
|
|
N(N_),
|
|
K(K_),
|
|
stride_A(stride_A_),
|
|
stride_B(stride_B_),
|
|
stride_E(stride_E_),
|
|
k_batch(k_batch_)
|
|
{
|
|
}
|
|
|
|
const void* a_ptr;
|
|
const void* b_ptr;
|
|
union
|
|
{
|
|
void* e_ptr;
|
|
void* c_ptr;
|
|
};
|
|
|
|
index_t M;
|
|
index_t N;
|
|
index_t K;
|
|
index_t stride_A;
|
|
index_t stride_B;
|
|
|
|
union
|
|
{
|
|
index_t stride_E;
|
|
index_t stride_C;
|
|
};
|
|
|
|
index_t k_batch;
|
|
};
|
|
|
|
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
|
|
struct GemmKernel
|
|
{
|
|
/// @brief Inject the UniversalGemmKernel base class to support execution of all necessary
|
|
/// functions.
|
|
using UniversalGemmKernel =
|
|
UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
|
|
|
|
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
|
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
|
|
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
|
|
|
/// @brief Specify the layout configurations for A, B, E and D
|
|
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
|
|
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
|
|
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
|
|
|
/// @brief Specify the data type configurations for A, B, E and D
|
|
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
|
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
|
|
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
|
|
|
/// @brief ALayout and ADataType are expected to be scalars, not a tuple.
|
|
static_assert(
|
|
!is_detected<is_tuple, ALayout>::value && !is_detected<is_tuple, ADataType>::value,
|
|
"ALayout and ADataType must be scalars. Multiple parameters are not currently supported.");
|
|
|
|
/// @brief BLayout and BDataType are expected to be scalars, not a tuple.
|
|
static_assert(
|
|
!is_detected<is_tuple, BLayout>::value && !is_detected<is_tuple, BDataType>::value,
|
|
"BLayout and BDataType must be scalars. Multiple parameters are not currently supported.");
|
|
|
|
/// @brief C/CLayout and C/EDataType are expected to be scalars, not a tuple.
|
|
static_assert(!is_detected<is_tuple, CLayout>::value &&
|
|
!is_detected<is_tuple, EDataType>::value,
|
|
"C/CLayout and C/EDataType must be scalars.");
|
|
|
|
static constexpr index_t NumATensor = 1;
|
|
static constexpr index_t NumBTensor = 1;
|
|
static constexpr index_t kBlockSize = UniversalGemmKernel::kBlockSize;
|
|
|
|
CK_TILE_HOST static auto GetName() -> const std::string
|
|
{
|
|
return UniversalGemmKernel::GetName();
|
|
}
|
|
|
|
CK_TILE_HOST static constexpr auto ClusterSize() { return UniversalGemmKernel::ClusterSize(); }
|
|
|
|
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) -> dim3
|
|
{
|
|
return UniversalGemmKernel::GridSize(M, N, KBatch);
|
|
}
|
|
|
|
CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
|
|
{
|
|
return UniversalGemmKernel::MaxOccupancyGridSize(s);
|
|
}
|
|
|
|
CK_TILE_HOST static constexpr auto BlockSize() -> dim3
|
|
{
|
|
return UniversalGemmKernel::BlockSize();
|
|
}
|
|
|
|
CK_TILE_HOST static constexpr auto MakeKernelArgs(const GemmHostArgs& hostArgs) ->
|
|
typename UniversalGemmKernel::KernelArgs
|
|
{
|
|
/// @brief Universal GEMM requires array objects and corresponding stride information for
|
|
/// matrices A, B.
|
|
return UniversalGemmKernel::MakeKernelArgs(
|
|
UniversalGemmHostArgs<NumATensor, NumBTensor /*NumDTensor = 0 */>(
|
|
{hostArgs.a_ptr},
|
|
{hostArgs.b_ptr},
|
|
{/*hostArgs.ds_ptr*/},
|
|
hostArgs.e_ptr,
|
|
hostArgs.k_batch,
|
|
hostArgs.M,
|
|
hostArgs.N,
|
|
hostArgs.K,
|
|
{hostArgs.stride_A},
|
|
{hostArgs.stride_B},
|
|
{/*hostArgs.stride_Ds*/},
|
|
hostArgs.stride_E));
|
|
}
|
|
|
|
CK_TILE_HOST static auto
|
|
IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs) -> bool
|
|
{
|
|
return UniversalGemmKernel::IsSupportedArgument(kargs);
|
|
}
|
|
|
|
CK_TILE_DEVICE auto operator()(typename UniversalGemmKernel::KernelArgs kargs) const -> void
|
|
{
|
|
UniversalGemmKernel{}.template operator()(kargs);
|
|
}
|
|
};
|
|
} // namespace ck_tile
|
|
|
|
// Restore the diagnostic state saved by the matching `push` at the top of this header.
// The condition must stay equivalent to the one guarding that push. `defined(__clang__)` is
// checked first so that non-clang compilers never evaluate the undefined `__clang_major__`
// macro, which -Wundef would report. On every compiler this has the same value as a bare
// `__clang_major__ >= 23`.
#if defined(__clang__) && __clang_major__ >= 23
#pragma clang diagnostic pop
#endif
|