diff --git a/experimental/grouped_convolution_tile_instances/generate_instances.py b/experimental/grouped_convolution_tile_instances/generate_instances.py index 5620b62228..2651b16b3a 100755 --- a/experimental/grouped_convolution_tile_instances/generate_instances.py +++ b/experimental/grouped_convolution_tile_instances/generate_instances.py @@ -118,9 +118,9 @@ def get_k_mfma(dtype, m_per_xdl, n_per_xdl): return 4 else: if m_per_xdl == 32: - return 8 - else: return 16 + else: + return 32 def check_vectors(a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector): @@ -444,7 +444,7 @@ def parse_fwd_instances(instances, problem_name): warp_size = 64 k_warp = int(block_size / (warp_size * m_warp * n_warp)) dtype = get_dtype(problem_name) - k_per_xdl = max(k1, get_k_mfma(dtype, m_per_xdl, n_per_xdl)) + k_per_xdl = min(max(k1, get_k_mfma(dtype, m_per_xdl, n_per_xdl)), k_per_block) if split_image: print( @@ -656,7 +656,7 @@ def parse_bwd_weight_instances(instances, problem_name): k_warp = int(block_size / (warp_size * m_warp * n_warp)) dtype = get_dtype(problem_name) - k_per_xdl = max(k1, get_k_mfma(dtype, m_per_xdl, n_per_xdl)) + k_per_xdl = min(max(k1, get_k_mfma(dtype, m_per_xdl, n_per_xdl)), k_per_block) if not check_vectors( a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector @@ -792,7 +792,7 @@ def parse_bwd_data_instances(instances, problem_name): k_warp = int(block_size / (warp_size * m_warp * n_warp)) dtype = get_dtype(problem_name) - k_per_xdl = max(k1, get_k_mfma(dtype, m_per_xdl, n_per_xdl)) + k_per_xdl = min(max(k1, get_k_mfma(dtype, m_per_xdl, n_per_xdl)), k_per_block) if check_vectors(a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector) == False: print(f"Skipping instance {instance_id} with irregular load since it's not supported yet.")