diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp index cfbd78967f..7d665f5428 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp @@ -57,8 +57,8 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy using WG = WarpGemmMfmaDispatcher && std::is_same_v && std::is_same_v) @@ -83,6 +84,7 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy wg_attr_num_access>; return make_tuple(WG{}, 4, 1); } +#endif else { static_assert(false, "Unsupported data type configuration for GEMM warp execution."); diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index 6627bee87c..7eb665f59c 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -84,6 +84,18 @@ using WarpGemmWmmaF16F16F32M16N16K16TransposedCDistribution = WarpGemmImpl>>; +#endif + +using WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution = + WarpGemmImpl>>; + +using WarpGemmWmmaF16F16F32M16N16K16TransposedCDistribution = + WarpGemmImpl>>; + +#if 0 + #if defined(__gfx950__) template using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution = diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_generic_impl_F16F16F32M16N16K16.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_generic_impl_F16F16F32M16N16K16.hpp index e7620260a9..f8765782dc 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_generic_impl_F16F16F32M16N16K16.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_generic_impl_F16F16F32M16N16K16.hpp @@ -130,3 +130,5 @@ struct WarpGemmAttributeGenericImplF16F16F32M16N16K16 #endif } }; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index 97e4916cd8..ac4e081105 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -34,6 +34,7 @@ struct WarpGemmWmmaDispatcher; // clang-format off // fp16 // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity +#if 0 template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16<>; }; @@ -42,10 +43,14 @@ template<> struct WarpGemmMfmaDispatcher; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution; }; +#endif + template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K16; }; template<> struct WarpGemmWmmaDispatcher { using Type = WarpGemmWmmaF16F16F32M16N16K16; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; }; template<> struct WarpGemmWmmaDispatcher { using Type = WarpGemmWmmaF16F16F32M16N16K16TransposedCDistribution; }; + +#if 0 template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32<>; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<>; }; template<> struct WarpGemmMfmaDispatcher { @@ -128,6 +133,7 @@ template<> struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_i32_32x32x16_i8_i8_CTransposed; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_i32_16x16x32_i8_i8; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed; }; +#endif // clang-format on } // namespace impl