diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp index 84883d6ed8..beab457b90 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp @@ -35,16 +35,13 @@ struct BlockGemmARegBSmemCRegV1 std::is_same_v>, "wrong!"); - // constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; - // constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; - // constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; - constexpr index_t MPerBlock = BlockGemmShape::kM; - constexpr index_t NPerBlock = BlockGemmShape::kN; - constexpr index_t KPerBlock = BlockGemmShape::kK; + constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; - // static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && - // KPerBlock == BlockGemmShape::kK, - // "wrong!"); + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp index 65ce1a9b8f..3d142df4d4 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp @@ -35,16 +35,13 @@ struct BlockGemmASmemBRegCRegV1 std::is_same_v>, "wrong!"); - // constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}]; - // constexpr index_t NPerBlock = BBlockTensorTmp{}.get_lengths()[number<0>{}]; - // constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}]; - constexpr index_t MPerBlock = BlockGemmShape::kM; - constexpr index_t NPerBlock = BlockGemmShape::kN; - constexpr index_t KPerBlock = BlockGemmShape::kK; + constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockTensorTmp{}.get_lengths()[number<0>{}]; + constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}]; - // static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && - // KPerBlock == BlockGemmShape::kK, - // "wrong!"); + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); constexpr auto config = Policy::template GetWarpGemmMWarpNWarp();