From 04ed7d9ba9633367a16c87a3b656e8741e262024 Mon Sep 17 00:00:00 2001 From: Johannes Graner Date: Wed, 28 Jan 2026 15:38:10 +0100 Subject: [PATCH] Update pytorch version in convolution dataset test generation (#3667) * Update torch version in dataset test gen [ROCm/composable_kernel commit: bc6083bdd466d1e060253e7a49626c923293c483] --- .gitignore | 1 + test_data/generate_test_dataset.sh | 6 ++---- test_data/requirements.txt | 3 +++ 3 files changed, 6 insertions(+), 4 deletions(-) create mode 100644 test_data/requirements.txt diff --git a/.gitignore b/.gitignore index a2fb1473ab..17f93500bd 100644 --- a/.gitignore +++ b/.gitignore @@ -103,6 +103,7 @@ dispatcher/**/dispatcher_kernels.json test_data/* !test_data/*.py !test_data/*.sh +!test_data/requirements.txt # Exceptions to build* patterns above # The experimental/builder directory should be tracked despite matching build* diff --git a/test_data/generate_test_dataset.sh b/test_data/generate_test_dataset.sh index 27f45a3bc7..5cbc5514e6 100755 --- a/test_data/generate_test_dataset.sh +++ b/test_data/generate_test_dataset.sh @@ -50,10 +50,8 @@ if ! python3 -c "import torch" 2>/dev/null; then # Install PyTorch in virtual environment with ROCm support echo "Installing PyTorch and torchvision with ROCm support in virtual environment..." - # Since we're in a ROCm 6.4.1 environment, we need compatible PyTorch - # PyTorch doesn't have 6.4 wheels yet, so we use 6.2 which should be compatible - echo "Installing PyTorch with ROCm 6.2 support (compatible with ROCm 6.4)..." - pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/rocm6.2 || { + echo "Installing PyTorch with ROCm 7.1 support..." + pip install -r requirements.txt || { echo "ERROR: Failed to install PyTorch with ROCm support." echo "Creating empty CSV files as fallback..." echo "# 2D Convolution Test Cases" > conv_test_set_2d_dataset.csv diff --git a/test_data/requirements.txt b/test_data/requirements.txt new file mode 100644 index 0000000000..ecf05539f5 --- /dev/null +++ b/test_data/requirements.txt @@ -0,0 +1,3 @@ +-i https://download.pytorch.org/whl/rocm7.1 +torch==2.10.* +torchvision==0.25.* \ No newline at end of file