From 07a1f38e95d9a9b1440f900d5e3efd17b8cc627c Mon Sep 17 00:00:00 2001
From: Wilson Yan
Date: Wed, 2 Mar 2022 17:37:09 +0000
Subject: [PATCH] Fixed google drive downloading for pretrained models
---
requirements.txt | 1 +
videogpt/download.py | 31 ++-----------------------------
2 files changed, 3 insertions(+), 29 deletions(-)
diff --git a/requirements.txt b/requirements.txt
index 32215a6..506ab1c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,3 +10,4 @@ av
gradio
moviepy
imageio
+gdown
diff --git a/videogpt/download.py b/videogpt/download.py
index fbfa640..0d7fe23 100644
--- a/videogpt/download.py
+++ b/videogpt/download.py
@@ -1,31 +1,13 @@
import requests
from tqdm import tqdm
import os
+import gdown
import torch
from .vqvae import VQVAE
from .gpt import VideoGPT
-def get_confirm_token(response):
- for key, value in response.cookies.items():
- if key.startswith('download_warning'):
- return value
- return None
-
-
-def save_response_content(response, destination):
- CHUNK_SIZE = 8192
-
- pbar = tqdm(total=0, unit='iB', unit_scale=True)
- with open(destination, 'wb') as f:
- for chunk in response.iter_content(CHUNK_SIZE):
- if chunk:
- f.write(chunk)
- pbar.update(len(chunk))
- pbar.close()
-
-
def download(id, fname, root=os.path.expanduser('~/.cache/videogpt')):
os.makedirs(root, exist_ok=True)
destination = os.path.join(root, fname)
@@ -33,16 +15,7 @@ def download(id, fname, root=os.path.expanduser('~/.cache/videogpt')):
if os.path.exists(destination):
return destination
- URL = 'https://drive.google.com/uc?export=download'
- session = requests.Session()
-
- response = session.get(URL, params={'id': id}, stream=True)
- token = get_confirm_token(response)
-
- if token:
- params = {'id': id, 'confirm': token}
- response = session.get(URL, params=params, stream=True)
- save_response_content(response, destination)
+ gdown.download(id=id, output=destination, quiet=False)
return destination