diff --git a/experimental/grouped_convolution_tile_instances/generate_instances.py b/experimental/grouped_convolution_tile_instances/generate_instances.py index 5346f6d2cb..b2cb37c39b 100755 --- a/experimental/grouped_convolution_tile_instances/generate_instances.py +++ b/experimental/grouped_convolution_tile_instances/generate_instances.py @@ -317,7 +317,7 @@ def parse_bwd_weight_instances(instances, problem_name): gemm_params = device_op_name = instance.split("<")[2].split(">")[1].split(",") args = [param.split(":")[1].strip() for param in gemm_params] - spec = "Default" + spec = "Filter1x1Stride1Pad0" block_size = int(args[0]) mnk_per_block = args[1].split("x") @@ -450,6 +450,13 @@ def parse_bwd_weight_instances(instances, problem_name): if pipeline_version == "V6": print(f"Skipping instance {instance_id} with V6 since it's not supported yet.") continue + if m_per_block > (warp_size * a_scalar_per_vector) or n_per_block > (warp_size * b_scalar_per_vector): + print(f"Skipping instance {instance_id} with multiple warps per continous tile dim since it's not supported yet.") + continue + + if is_explicit_gemm: + if dtype != "float" and c_scalar_per_vector % 2 != 0: + is_two_stage_instance = True conv = ConvInstanceTemplateParams( spec, diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index 918eb3de26..a3268fa2ff 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -449,6 +449,14 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1 CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() { - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t VecLoadSize = - Problem::FixedVectorSize ? Problem::VectorSizeA : GetVectorSizeA(); + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t VecLoadSize = GetVectorSizeA(); constexpr index_t NumWaveGroups = Problem::NumWaveGroups; using ALayout = remove_cvref_t< @@ -756,9 +755,7 @@ struct UniversalGemmBasePolicy // since the assumption is that A type is going to be the B LDS type constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v; constexpr index_t VecLoadSize = - IsBCastPolicyBeforeLDSWrite - ? (Problem::FixedVectorSize ? Problem::VectorSizeA : GetVectorSizeA()) - : (Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB()); + IsBCastPolicyBeforeLDSWrite ? GetVectorSizeA() : GetVectorSizeB(); constexpr index_t NumWaveGroups = Problem::NumWaveGroups; using BLayout = remove_cvref_t< std::tuple_element_t{}, remove_cvref_t>>;