From b1d99214143bcb20a3ce2ef52700fc33e7418937 Mon Sep 17 00:00:00 2001 From: Ailing Zhang Date: Thu, 8 Aug 2019 18:57:39 -0700 Subject: [PATCH] fix urllib error on colab --- hubconf.py | 29 +++++++++++++---------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/hubconf.py b/hubconf.py index 66788b84a..a248f6cab 100644 --- a/hubconf.py +++ b/hubconf.py @@ -34,6 +34,15 @@ def unwrap_distributed(state_dict): new_state_dict[new_key] = value return new_state_dict +def _download_checkpoint(checkpoint, force_reload): + model_dir = os.path.join(torch.hub._get_torch_home(), 'checkpoints') + if not os.path.exists(model_dir): + os.makedirs(model_dir) + ckpt_file = os.path.join(model_dir, os.path.basename(checkpoint)) + if not os.path.exists(ckpt_file) or force_reload: + sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint)) + urllib.request.urlretrieve(checkpoint, ckpt_file) + return ckpt_file dependencies = ['torch'] @@ -66,10 +75,7 @@ def nvidia_ncf(pretrained=True, **kwargs): checkpoint = 'https://developer.nvidia.com/joc-ncf-fp16-pyt-20190225' else: checkpoint = 'https://developer.nvidia.com/joc-ncf-fp32-pyt-20190225' - ckpt_file = os.path.join(torch.hub._get_torch_home(), "checkpoints", os.path.basename(checkpoint)) - if not os.path.exists(ckpt_file) or force_reload: - sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint)) - urllib.request.urlretrieve(checkpoint, ckpt_file) + ckpt_file = _download_checkpoint(checkpoint, force_reload) ckpt = torch.load(ckpt_file) if checkpoint_from_distributed(ckpt): @@ -130,10 +136,7 @@ def nvidia_tacotron2(pretrained=True, **kwargs): checkpoint = 'https://developer.nvidia.com/joc-tacotron2-fp16-pyt-20190306' else: checkpoint = 'https://developer.nvidia.com/joc-tacotron2-fp32-pyt-20190306' - ckpt_file = os.path.join(torch.hub._get_torch_home(), "checkpoints", os.path.basename(checkpoint)) - if not os.path.exists(ckpt_file) or force_reload: - sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint)) - urllib.request.urlretrieve(checkpoint, ckpt_file) + ckpt_file = _download_checkpoint(checkpoint, force_reload) ckpt = torch.load(ckpt_file) state_dict = ckpt['state_dict'] if checkpoint_from_distributed(state_dict): @@ -190,10 +193,7 @@ def nvidia_waveglow(pretrained=True, **kwargs): checkpoint = 'https://developer.nvidia.com/joc-waveglow-fp16-pyt-20190306' else: checkpoint = 'https://developer.nvidia.com/joc-waveglow-fp32-pyt-20190306' - ckpt_file = os.path.join(torch.hub._get_torch_home(), "checkpoints", os.path.basename(checkpoint)) - if not os.path.exists(ckpt_file) or force_reload: - sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint)) - urllib.request.urlretrieve(checkpoint, ckpt_file) + ckpt_file = _download_checkpoint(checkpoint, force_reload) ckpt = torch.load(ckpt_file) state_dict = ckpt['state_dict'] if checkpoint_from_distributed(state_dict): @@ -360,10 +360,7 @@ def batchnorm_to_float(module): checkpoint = 'https://developer.nvidia.com/joc-ssd-fp16-pyt-20190225' else: checkpoint = 'https://developer.nvidia.com/joc-ssd-fp32-pyt-20190225' - ckpt_file = os.path.join(torch.hub._get_torch_home(), "checkpoints", os.path.basename(checkpoint)) - if not os.path.exists(ckpt_file) or force_reload: - sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint)) - urllib.request.urlretrieve(checkpoint, ckpt_file) + ckpt_file = _download_checkpoint(checkpoint, force_reload) ckpt = torch.load(ckpt_file) ckpt = ckpt['model'] if checkpoint_from_distributed(ckpt):