mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 21:27:45 +00:00
Concept bug fixes.
This commit is contained in:
@@ -157,36 +157,36 @@ concept SpecifiesTileThreadBlock = requires {
|
||||
|
||||
// Concept to check if a struct specifies gridwise XDL GEMM info.
|
||||
template <typename T>
|
||||
concept GridwiseFwdXdlGemmDescriptor = requires {
|
||||
{ T::ak1 } -> std::convertible_to<size_t>;
|
||||
{ T::bk1 } -> std::convertible_to<size_t>;
|
||||
{ T::xdl_params } -> GridwiseXdlGemmDescriptor;
|
||||
concept GridwiseFwdXdlGemmDescriptor = requires (T t){
|
||||
{ t.ak1 } -> std::convertible_to<size_t>;
|
||||
{ t.bk1 } -> std::convertible_to<size_t>;
|
||||
{ t.xdl_params } -> GridwiseXdlGemmDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies gridwise XDL GEMM info.
|
||||
template <typename T>
|
||||
concept GridwiseBwdXdlGemmDescriptor = requires {
|
||||
{ T::k0_per_block } -> std::convertible_to<size_t>;
|
||||
{ T::k1 } -> std::convertible_to<size_t>;
|
||||
{ T::xdl_params } -> GridwiseXdlGemmDescriptor;
|
||||
concept GridwiseBwdXdlGemmDescriptor = requires (T t){
|
||||
{ t.k0_per_block } -> std::convertible_to<size_t>;
|
||||
{ t.k1 } -> std::convertible_to<size_t>;
|
||||
{ t.xdl_params } -> GridwiseXdlGemmDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies gridwise XDL GEMM info.
|
||||
template <typename T>
|
||||
concept SpecifiesGridwiseFwdXdlGemm = requires {
|
||||
{ T::gridwise_gemm } -> GridwiseFwdXdlGemmDescriptor;
|
||||
concept SpecifiesGridwiseFwdXdlGemm = requires (T t) {
|
||||
{ t.gridwise_gemm } -> GridwiseFwdXdlGemmDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies gridwise XDL GEMM info.
|
||||
template <typename T>
|
||||
concept SpecifiesGridwiseBwdXdlGemm = requires {
|
||||
{ T::gridwise_gemm } -> GridwiseFwdXdlGemmDescriptor;
|
||||
concept SpecifiesGridwiseBwdXdlGemm = requires (T t) {
|
||||
{ t.gridwise_gemm } -> GridwiseBwdXdlGemmDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies gridwise WMMA GEMM info.
|
||||
template <typename T>
|
||||
concept SpecifiesGridwiseWmmaGemm = requires {
|
||||
{ T::gridwise_gemm } -> GridwiseBwdXdlGemmDescriptor;
|
||||
concept SpecifiesGridwiseWmmaGemm = requires (T t){
|
||||
{ t.gridwise_gemm } -> GridwiseWmmaGemmDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies convolution input and output block transfer info.
|
||||
|
||||
@@ -161,7 +161,7 @@ consteval ck::tensor_operation::device::ConvolutionForwardSpecialization SetFwdC
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization SetBwdWeightConvSpecialization()
|
||||
{
|
||||
constexpr auto specialization = ALGORITHM.bwd_specialization;
|
||||
constexpr auto specialization = ALGORITHM.bwd_weight_specialization;
|
||||
using ck_conv_spec = ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization;
|
||||
switch(specialization)
|
||||
{
|
||||
|
||||
@@ -25,7 +25,8 @@ constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CSh
|
||||
.with_gemm_config(cku::BwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_transfer(cku::Transfer_4x64x1)
|
||||
.with_bwd_specialization(ckb::ConvSpecialization::DEFAULT);
|
||||
|
||||
static_assert(cku::SpecifiesGridwiseBwdXdlGemm<decltype(ALGORITHM)>, "Error");
|
||||
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
using Instance = Builder::Instance;
|
||||
|
||||
|
||||
@@ -211,7 +211,7 @@ struct ConvSpecializationFwd_
|
||||
|
||||
struct ConvSpecializationBwdWeight_
|
||||
{
|
||||
ConvSpecialization bwd_specialization;
|
||||
ConvSpecialization bwd_weight_specialization;
|
||||
};
|
||||
|
||||
struct Prefetch_
|
||||
@@ -400,7 +400,7 @@ struct ConvAlgorithmTemplate : Components...
|
||||
{
|
||||
static_assert(std::is_base_of_v<ConvSpecializationBwdWeight_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.bwd_specialization = bwd_spec;
|
||||
result.bwd_weight_specialization = bwd_spec;
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
@@ -278,7 +278,7 @@ template <>
|
||||
inline std::string to_string<ConvSpecializationBwdWeight_>(ConvSpecializationBwdWeight_ t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(t.bwd_specialization);
|
||||
oss << to_string(t.bwd_weight_specialization);
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user