Skip to content

Instantly share code, notes, and snippets.

@curioyang
Forked from peterjc123/check_import.py
Created September 25, 2024 02:51
Show Gist options
Save curioyang/98bd6a9f5f800a53ad2589109164e984 to your computer and use it in GitHub Desktop.
Helps detect import errors for PyTorch on Windows.
# This script tries to figure out the cause of an
# `ImportError` raised by `import torch` on Windows.
# Run it with `python check_import.py`.
import ctypes
import glob
import os
import sys
import subprocess
def infer_torch_root():
    """Guess the directory of the ``torch`` package that ``import torch`` would use.

    Walks ``sys.path`` in order and returns the first ``<entry>/torch`` that
    is an existing directory. If no entry has one, falls back to
    ``<sys.exec_prefix>\\Lib\\site-packages\\torch`` (the usual site-packages
    location of a Windows install).

    Returns:
        str: path of the first matching ``torch`` directory, or the fallback
        path (which may not exist).
    """
    default_path = os.path.join(sys.exec_prefix, 'Lib\\site-packages\\torch')
    for entry in sys.path:
        # abspath('') resolves to the current directory, matching how an
        # empty sys.path entry is treated by the import system.
        candidate_path = os.path.join(os.path.abspath(entry), 'torch')
        # A package is a directory; a stray file named "torch" is not
        # importable and must not be reported as the torch root.
        if os.path.isdir(candidate_path):
            return candidate_path
    return default_path
# Python 2/3 compatibility flag.
PY3 = sys.version_info >= (3, 0)
# Refuse to run on any 3.6.0 interpreter (including its pre-releases); the
# message asks the user to move to a later 3.6.x patch release.
if sys.version_info[:3] == (3, 6, 0):
    print('Please update your python to 3.6.X (X>0) first.')
    # sys.exit rather than the site-provided exit(), which is not defined
    # when Python runs with -S or is embedded.
    sys.exit(1)
# Directory of the torch package that `import torch` is expected to load.
TORCH_ROOT = infer_torch_root()
# <dir of python.exe>\Library\bin -- presumably the conda-style DLL directory
# (e.g. for MKL); verify for non-conda installs.
PY_DLL_PATH = os.path.join(os.path.dirname(sys.executable), 'Library\\bin')
# DLL directory inside the torch package.
TORCH_DLL_PATH = os.path.join(TORCH_ROOT, 'lib')
# NVIDIA Tools Extension directory, from NVTOOLSEXT_PATH when it is set.
NVTOOLEXT_HOME = os.getenv(
    'NVTOOLSEXT_PATH', 'C:\\Program Files\\NVIDIA Corporation\\NvToolsExt')
# normpath drops a trailing separator first; otherwise dirname() of a value
# such as '...\\NvToolsExt\\' returns NvToolsExt itself instead of its parent.
NV_ROOT = os.path.dirname(os.path.normpath(NVTOOLEXT_HOME))
# The script treats an _nvrtc*.pyd in the torch directory as a CUDA build.
IS_CUDA = len(glob.glob(os.path.join(TORCH_ROOT, '_nvrtc*.pyd'))) > 0
IS_CONDA = 'conda' in sys.version or 'Continuum' in sys.version or any(
    x.startswith('CONDA') for x in os.environ)
# DLL names probed with ctypes when diagnosing "DLL load failed", grouped by
# the component that provides them.
VC_LIBS = ['msvcp140.dll']
MKL_LIBS = ['mkl_rt.dll']
INTEL_OPENMP_LIBS = ['libiomp5md.dll']
CUDA_LIBS = ['nvcuda.dll',
             'nvToolsExt64_1.dll',
             'nvfatbinaryLoader.dll']
TORCH_LIBS = ['c10.dll']
def add_paths(paths):
    """Prepend each entry of *paths* to the ``PATH`` environment variable.

    Entries are prepended one at a time, so the last element of *paths* ends
    up first in ``PATH``. Works even when ``PATH`` is not set at all.

    Args:
        paths: iterable of directory strings.
    """
    for path in paths:
        current = os.environ.get('PATH', '')
        # Avoid a dangling separator (an empty PATH entry) when PATH is empty.
        os.environ['PATH'] = path + os.pathsep + current if current else path
def get_output(command):
    """Run *command* through the shell and return its stripped stdout.

    Args:
        command: command-line string, executed with ``shell=True``.

    Returns:
        The command's stdout with surrounding whitespace stripped when the
        exit code is 0; ``None`` for any non-zero exit code. stderr is
        captured and discarded.
    """
    p = subprocess.Popen(command, stdout=subprocess.PIPE,
                         stderr=subprocess.PIPE, shell=True)
    output, _ = p.communicate()
    if p.returncode != 0:
        return None
    # Python 3 pipes yield bytes; on Python 2 the output is already a str.
    # Checking the type keeps this helper independent of module globals.
    if not isinstance(output, str):
        import locale
        # Decode with the locale encoding (subprocess's default for text
        # mode) and never raise on unexpected bytes, e.g. non-ASCII GPU
        # names or user paths.
        output = output.decode(locale.getpreferredencoding(False), 'replace')
    return output.strip()
def get_file_path(filename):
    """Locate *filename* on ``PATH`` with the Windows ``where`` command.

    Args:
        filename: file name to search for, e.g. ``'pip.exe'``.

    Returns:
        The first path printed by ``where``, or ``None`` when the command
        fails (nothing found) or prints nothing.
    """
    out = get_output('where ' + filename)
    if not out:
        return None
    # `where` prints one match per line; keep the first. splitlines() copes
    # with both '\r\n' and '\n' line endings.
    return out.splitlines()[0]
def detect_reason(raw_message):
    """Translate the message of a failed ``import torch`` into advice.

    Args:
        raw_message: ``str()`` of the exception raised by ``import torch``.

    Returns:
        str: a multi-line diagnosis. It recognises two message families:
        "No module named torch" (torch is missing, or installed into another
        environment) and messages starting with "DLL load failed" (a native
        dependency could not be loaded). Any other message is echoed back as
        unsupported.
    """
    message = ''
    # Python 2 and Python 3 word the missing-module error differently.
    if raw_message in ('No module named torch', 'No module named \'torch\''):
        # detect pip / python path mismatch
        mismatch, pip_path, python_path = detect_install_import_mismatch()
        if mismatch:
            message += 'Probably you installed torch in one environment '
            message += 'but imported in another one.\n'
            message += 'Detected pip path: %s\n' % pip_path
            message += 'Detected python path: %s\n' % python_path
        else:
            message += 'It seems that torch is not installed.\n'
            message += 'Please refer to https://pytorch.org for installation.\n'
    elif raw_message.startswith('DLL load failed'):
        # Prepend the NvToolsExt and Python/conda DLL directories to PATH
        # before probing. NOTE(review): since Python 3.8, ctypes.CDLL does
        # not search PATH for bare DLL names by default, so this may not
        # affect the probes below there -- confirm.
        add_paths([NVTOOLEXT_HOME, PY_DLL_PATH])
        message += check_dependents(TORCH_LIBS, 'PyTorch', [
            '1. Please change your current directory.',
            '2. Please reinstall torch.'])
        message += check_dependents(
            VC_LIBS, 'VC Redist',
            'Please refer to https://aka.ms/vs/15/release/VC_redist.x64.exe for installation.')
        message += check_dependents(MKL_LIBS, 'MKL',
                                    '`conda install mkl` or `pip install mkl`')
        message += check_dependents(
            INTEL_OPENMP_LIBS, 'intel-openmp',
            '`conda install intel-openmp` or `pip install intel-openmp`')
        if IS_CUDA:
            if detect_nv_card():
                message += check_dependents(
                    CUDA_LIBS, 'CUDA',
                    'Please refer to https://developer.nvidia.com/cuda-downloads for installation.')
            else:
                message += 'It seems that you don\'t have NV cards. Please use CPU version instead.\n'
        if message == '':
            # Every probed DLL loaded, so suggest fixing the DLL search path
            # used by the real import.
            message += 'It seems `import torch` should work.\n'
            message += 'You may try to add `%s` to the environment variable `PATH`.\n' % PY_DLL_PATH
            message += 'And make sure you restart the command prompt when you apply any changes to the environment.\n'
    else:
        message += 'Sorry, we don\'t support this kind of message at present.\n'
        message += 'Original message:\n'
        message += raw_message
    return message
def detect_install_import_mismatch():
    """Check whether the ``pip.exe`` found on PATH belongs to the running Python.

    Returns:
        tuple: ``(mismatch, pip_path, python_path)``.
            ``mismatch``: True when the interpreter's directory is neither
                pip.exe's directory nor its parent.
            ``pip_path``: the first ``pip.exe`` on PATH.
            ``python_path``: the normalised directory of ``sys.executable``.
        If either path is unavailable, the result is
        ``(False, pip_path, sys.executable)``, so callers can always unpack
        three values.
    """
    pip_path = get_file_path('pip.exe')
    python_path = sys.executable
    if not pip_path or not python_path:
        # Nothing to compare; still return a 3-tuple for the caller.
        return False, pip_path, python_path

    def _key(path):
        # Windows paths are case-insensitive; normcase folds case there.
        return os.path.normcase(os.path.normpath(path))

    pip_dir = os.path.dirname(pip_path)
    python_dir = os.path.normpath(os.path.dirname(python_path))
    # Accept both layouts: <prefix>\python.exe with <prefix>\Scripts\pip.exe
    # (regular and conda installs), and <venv>\Scripts\python.exe next to
    # <venv>\Scripts\pip.exe (venv).
    mismatch = _key(python_dir) not in (_key(os.path.dirname(pip_dir)),
                                        _key(pip_dir))
    return mismatch, pip_path, python_dir
def check_dependents(dependents, name, solution):
    """Try to load each DLL with ctypes and describe the ones that fail.

    Args:
        dependents: iterable of DLL names (or paths) to load.
        name: human-readable component the DLLs belong to, or ``None`` to
            omit the "It is a component of ..." line.
        solution: a hint string, a list of hint lines (joined with
            newlines), or a falsy value to omit the "Possible solution"
            section.

    Returns:
        str: one report section per DLL that failed to load, or ``''`` when
        every DLL loaded. Nothing is printed.
    """
    sections = []
    for dll in dependents:
        try:
            ctypes.CDLL(dll)
        except OSError as e:
            # ctypes reports a library that cannot be found or loaded as
            # OSError (WindowsError / FileNotFoundError on Windows).
            lines = ['DLL loading %s failed' % dll,
                     'Original error message:',
                     str(e)]
            if name is not None:
                lines.append('It is a component of %s' % name)
            if solution:
                lines.append('Possible solution:')
                if isinstance(solution, list):
                    lines.append('\n'.join(solution))
                else:
                    lines.append(solution)
            sections.append('\n'.join(lines) + '\n')
    return ''.join(sections)
def detect_nv_card():
    """Return True if Windows reports a video controller named like 'NVIDIA'.

    Queries the ``Win32_VideoController`` names with ``wmic``. If that
    command fails or is unavailable (``get_output`` returns ``None``), the
    same query is retried through PowerShell's ``Get-CimInstance``; wmic is
    deprecated and may be missing on recent Windows installs.

    Returns:
        bool: True if the first successful query's output contains 'NVIDIA';
        False otherwise, including when no query succeeds.
    """
    queries = (
        'wmic path win32_VideoController get name',
        'powershell -NoProfile -Command '
        '"Get-CimInstance Win32_VideoController | '
        'Select-Object -ExpandProperty Name"',
    )
    for query in queries:
        gpu_names = get_output(query)
        if gpu_names is not None:
            return 'NVIDIA' in gpu_names
    return False
def main():
    """Try ``import torch`` and print either a success line or a diagnosis."""
    try:
        import torch  # noqa: F401  (imported only to check that it loads)
    except ImportError as e:
        print(detect_reason(str(e)))
    except OSError as e:
        # torch/__init__.py can raise OSError (e.g. "[WinError 126] ... Error
        # loading ...c10.dll or one of its dependencies") instead of an
        # ImportError when a bundled DLL fails to load. Run the same DLL
        # dependency checks as for a "DLL load failed" ImportError, then show
        # the original error, since that report does not include it.
        print(detect_reason('DLL load failed: ' + str(e)))
        print('Original error message:')
        print(str(e))
    else:
        print('`import torch` works perfectly.')
if __name__ == '__main__':
    # Run the diagnosis only when executed as a script, not when imported.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment