diff --git a/.gitignore b/.gitignore index a2fb1473ab..17f93500bd 100644 --- a/.gitignore +++ b/.gitignore @@ -103,6 +103,7 @@ dispatcher/**/dispatcher_kernels.json test_data/* !test_data/*.py !test_data/*.sh +!test_data/requirements.txt # Exceptions to build* patterns above # The experimental/builder directory should be tracked despite matching build* diff --git a/test_data/generate_test_dataset.sh b/test_data/generate_test_dataset.sh index 27f45a3bc7..5cbc5514e6 100755 --- a/test_data/generate_test_dataset.sh +++ b/test_data/generate_test_dataset.sh @@ -50,10 +50,8 @@ if ! python3 -c "import torch" 2>/dev/null; then # Install PyTorch in virtual environment with ROCm support echo "Installing PyTorch and torchvision with ROCm support in virtual environment..." - # Since we're in a ROCm 6.4.1 environment, we need compatible PyTorch - # PyTorch doesn't have 6.4 wheels yet, so we use 6.2 which should be compatible - echo "Installing PyTorch with ROCm 6.2 support (compatible with ROCm 6.4)..." - pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/rocm6.2 || { + echo "Installing PyTorch with ROCm 7.1 support..." + pip install -r requirements.txt || { echo "ERROR: Failed to install PyTorch with ROCm support." echo "Creating empty CSV files as fallback..." echo "# 2D Convolution Test Cases" > conv_test_set_2d_dataset.csv diff --git a/test_data/requirements.txt b/test_data/requirements.txt new file mode 100644 index 0000000000..ecf05539f5 --- /dev/null +++ b/test_data/requirements.txt @@ -0,0 +1,3 @@ +-i https://download.pytorch.org/whl/rocm7.1 +torch==2.10.* +torchvision==0.25.* \ No newline at end of file