diff --git a/model_test.py b/tests/model_test.py similarity index 99% rename from model_test.py rename to tests/model_test.py index d2490c6..6fc7825 100644 --- a/model_test.py +++ b/tests/model_test.py @@ -1,4 +1,3 @@ - from model import ModelContainer def progress(module, modules): diff --git a/tests/wheel_test.py b/tests/wheel_test.py new file mode 100644 index 0000000..d043b97 --- /dev/null +++ b/tests/wheel_test.py @@ -0,0 +1,47 @@ +import traceback +from importlib.metadata import version + +successful_packages = [] +errored_packages = [] + +try: + import flash_attn + print(f"Flash attention on version {version('flash_attn')} successfully imported") + successful_packages.append("flash_attn") +except: + print("Flash attention could not be loaded because:") + print(traceback.format_exc()) + errored_packages.append("flash_attn") + +try: + import exllamav2 + print(f"Exllamav2 on version {version('exllamav2')} successfully imported") + successful_packages.append("exllamav2") +except: + print("Exllamav2 could not be loaded because:") + print(traceback.format_exc()) + errored_packages.append("exllamav2") + +try: + import torch + print(f"Torch on version {version('torch')} successfully imported") + successful_packages.append("torch") +except: + print("Torch could not be loaded because:") + print(traceback.format_exc()) + errored_packages.append("torch") + +try: + import fastchat + print(f"Fastchat on version {version('fastchat')} successfully imported") + successful_packages.append("fastchat") +except: + print("Fastchat is only needed for chat completions with message arrays. Ignore this error if this isn't your usecase.") + print("Fastchat could not be loaded because:") + print(traceback.format_exc()) + errored_packages.append("fastchat") + +print( + f"\nSuccessful imports: {', '.join(successful_packages)}", + f"\nErrored imports: {''.join(errored_packages)}" +)