mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 04:49:54 +00:00
[CK_TILE] Add conv bwd data tests (#5646)
## Motivation This PR adds tests for CK Tile's convolution backward data operation to enable functionality regression tracking and error-detection. ## Technical Details Currently only NHWGC/GKCYX/NHWGK and NDHWGC/GKCZYX/NDHWGK(2 dim and 3 dim channel-last) layouts are being tested, since only they are implemented in CK Tile. Current tests support FP16, BF16 and FP32 datatypes and various different convolutions scenarios. The tested instances are listed in `experimental/grouped_convolution_tile_instances` directory. ## Test Result All implemented tests are working properly and passing. --------- Co-authored-by: Ville Pietilä <> Co-authored-by: Ville Pietilä <188998872+vpietila-amd@users.noreply.github.com> Co-authored-by: Jakub Piasecki <jakpia21@gmail.com>
This commit is contained in:
@@ -586,14 +586,12 @@ def parse_bwd_data_instances(instances, problem_name):
|
||||
if pipeline_version == "V6":
|
||||
print(f"Skipping instance {instance_id} with V6 since it's not supported yet.")
|
||||
continue
|
||||
|
||||
# Check vector sizes for A and B tensors - we cannot oversubscribe.
|
||||
num_tile_elements_a = m_per_xdl * k_per_xdl
|
||||
num_tile_elements_b = n_per_xdl * k_per_xdl
|
||||
max_vector_size_a = max(1, num_tile_elements_a // block_size)
|
||||
max_vector_size_b = max(1, num_tile_elements_b // block_size)
|
||||
a_scalar_per_vector = min(a_scalar_per_vector, max_vector_size_a)
|
||||
b_scalar_per_vector = min(b_scalar_per_vector, max_vector_size_b)
|
||||
if k_per_block > (warp_size * a_scalar_per_vector) or n_per_block > (warp_size * b_scalar_per_vector):
|
||||
print(f"Skipping instance {instance_id} with multiple warps per continous tile dim since it's not supported yet.")
|
||||
continue
|
||||
if a_scalar_per_vector > (m_per_block * k_per_block) // block_size or b_scalar_per_vector > (n_per_block * k_per_block) // block_size:
|
||||
print(f"Skipping instance {instance_id} because current scalar per vector exceedes tile size")
|
||||
continue
|
||||
|
||||
conv = ConvInstanceTemplateParams(
|
||||
spec,
|
||||
|
||||
Reference in New Issue
Block a user