Skip to content

Commit

Permalink
Fixed google drive downloading for pretrained models
Browse files Browse the repository at this point in the history
  • Loading branch information
wilson1yan committed Mar 2, 2022
1 parent c230bc3 commit 07a1f38
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 29 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ av
gradio
moviepy
imageio
gdown
31 changes: 2 additions & 29 deletions videogpt/download.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,21 @@
import requests
from tqdm import tqdm
import os
import gdown
import torch

from .vqvae import VQVAE
from .gpt import VideoGPT


def get_confirm_token(response):
for key, value in response.cookies.items():
if key.startswith('download_warning'):
return value
return None


def save_response_content(response, destination):
CHUNK_SIZE = 8192

pbar = tqdm(total=0, unit='iB', unit_scale=True)
with open(destination, 'wb') as f:
for chunk in response.iter_content(CHUNK_SIZE):
if chunk:
f.write(chunk)
pbar.update(len(chunk))
pbar.close()


def download(id, fname, root=os.path.expanduser('~/.cache/videogpt')):
os.makedirs(root, exist_ok=True)
destination = os.path.join(root, fname)

if os.path.exists(destination):
return destination

URL = 'https://drive.google.com/uc?export=download'
session = requests.Session()

response = session.get(URL, params={'id': id}, stream=True)
token = get_confirm_token(response)

if token:
params = {'id': id, 'confirm': token}
response = session.get(URL, params=params, stream=True)
save_response_content(response, destination)
gdown.download(id=id, output=destination, quiet=False)
return destination


Expand Down

0 comments on commit 07a1f38

Please sign in to comment.