From 07a1f38e95d9a9b1440f900d5e3efd17b8cc627c Mon Sep 17 00:00:00 2001 From: Wilson Yan Date: Wed, 2 Mar 2022 17:37:09 +0000 Subject: [PATCH] Fixed google drive downloading for pretrained models --- requirements.txt | 1 + videogpt/download.py | 31 ++----------------------------- 2 files changed, 3 insertions(+), 29 deletions(-) diff --git a/requirements.txt b/requirements.txt index 32215a6..506ab1c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,3 +10,4 @@ av gradio moviepy imageio +gdown diff --git a/videogpt/download.py b/videogpt/download.py index fbfa640..0d7fe23 100644 --- a/videogpt/download.py +++ b/videogpt/download.py @@ -1,31 +1,13 @@ import requests from tqdm import tqdm import os +import gdown import torch from .vqvae import VQVAE from .gpt import VideoGPT -def get_confirm_token(response): - for key, value in response.cookies.items(): - if key.startswith('download_warning'): - return value - return None - - -def save_response_content(response, destination): - CHUNK_SIZE = 8192 - - pbar = tqdm(total=0, unit='iB', unit_scale=True) - with open(destination, 'wb') as f: - for chunk in response.iter_content(CHUNK_SIZE): - if chunk: - f.write(chunk) - pbar.update(len(chunk)) - pbar.close() - - def download(id, fname, root=os.path.expanduser('~/.cache/videogpt')): os.makedirs(root, exist_ok=True) destination = os.path.join(root, fname) @@ -33,16 +15,7 @@ def download(id, fname, root=os.path.expanduser('~/.cache/videogpt')): if os.path.exists(destination): return destination - URL = 'https://drive.google.com/uc?export=download' - session = requests.Session() - - response = session.get(URL, params={'id': id}, stream=True) - token = get_confirm_token(response) - - if token: - params = {'id': id, 'confirm': token} - response = session.get(URL, params=params, stream=True) - save_response_content(response, destination) + gdown.download(id=id, output=destination, quiet=False) return destination