
Commit

Update apply preds
Alyetama committed Mar 14, 2022
1 parent 1e164a8 commit 1348a08
Showing 4 changed files with 653 additions and 58 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -165,3 +165,5 @@ class_names.npy
_ignore/
weights/
tasks_latest.json
picam/
Untitled.ipynb
173 changes: 118 additions & 55 deletions apply_predictions.py
@@ -1,20 +1,36 @@
import imghdr
import json
import os
import shutil
import signal
import sys
from glob import glob
from pathlib import Path

import numpy as np
import requests
from dotenv import load_dotenv
from PIL import Image, ImageOps
from requests.structures import CaseInsensitiveDict
from loguru import logger
from tqdm import tqdm

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import model_predict
from mongodb_helpers import get_mongodb_data


def keyboard_interrupt_handler(sig, frame):
logger.info(f'KeyboardInterrupt (ID: {sig}) has been caught...')
sys.exit(0)


def mkdirs():
Path('tmp').mkdir(exist_ok=True)
Path('tmp/downloaded').mkdir(exist_ok=True)
Path('tmp/cropped').mkdir(exist_ok=True)


def make_headers():
load_dotenv()
TOKEN = os.environ['TOKEN']
@@ -24,10 +40,9 @@ def make_headers():
return headers


-def get_all_tasks(project_id):
def get_all_tasks(headers, project_id):
    logger.debug('Getting tasks data... This might take few minutes...')
    url = f"https://ls.aibird.me/api/projects/{project_id}/tasks?page_size=10000"
-    headers = make_headers()
    resp = requests.get(url,
                        headers=headers,
                        data=json.dumps({'project': project_id}))
@@ -51,30 +66,89 @@ def predict(image_path):
return pred, prob


def load_local_image(img_path):
"""https://github.com/microsoft/CameraTraps/blob/main/classification/crop_detections.py"""
try:
with Image.open(img_path) as img:
img.load()
return img
except OSError as e:
exception_type = type(e).__name__
logger.error(f'Unable to load {img_path}. {exception_type}: {e}.')
return None


def save_crop(img, bbox_norm, square_crop, save):
"""https://github.com/microsoft/CameraTraps/blob/main/classification/crop_detections.py"""
img_w, img_h = img.size
xmin = int(bbox_norm[0] * img_w)
ymin = int(bbox_norm[1] * img_h)
box_w = int(bbox_norm[2] * img_w)
box_h = int(bbox_norm[3] * img_h)

if square_crop:
box_size = max(box_w, box_h)
xmin = max(0, min(xmin - int((box_size - box_w) / 2), img_w - box_w))
ymin = max(0, min(ymin - int((box_size - box_h) / 2), img_h - box_h))
box_w = min(img_w, box_size)
box_h = min(img_h, box_size)

if box_w == 0 or box_h == 0:
tqdm.write(f'Skipping size-0 crop (w={box_w}, h={box_h}) at {save}')
return False

crop = img.crop(box=[xmin, ymin, xmin + box_w,
ymin + box_h]) # [left, upper, right, lower]

if square_crop and (box_w != box_h):
crop = ImageOps.pad(crop, size=(box_size, box_size), color=0)
crop.save(save)
return os.path.dirname(save)


def main(task_id):
    headers = make_headers()
    url = f"https://ls.aibird.me/api/tasks/{task_id}"

    resp = requests.get(url, headers=headers)
    task_ = resp.json()
-    if not task_['predictions']:
-        img = task_['data']['image']
-    else:
    if task_['predictions']:
        return
    img_in_task = task_['data']['image']

    url = task_['data']['image'].replace('ls.aibird.me/data/local-files/?d=',
                                         'srv.aibird.me/')
    img_name = Path(img_in_task).name
    img_relative_path = f'tmp/downloaded/{img_name}'

    bbox_res = find_image(img_name)

    r = requests.get(url)
    with open(img_relative_path, 'wb') as f:
        f.write(r.content)

    if not imghdr.what(img_relative_path):
        logger.error(f'Not a valid image file: {img_relative_path}')
        return
-    md_preds = find_image(Path(img).name)

    img = load_local_image(img_relative_path)

    results = []
    scores = []
-    for item in md_preds['detections']:
-        if item['category'] != '1':
    bboxes = []

    for n, task in enumerate(bbox_res['detections']):
        if task['category'] != '1':
            continue
-        x, y, width, height = [x * 100 for x in item['bbox']]

-    for img_tuple in images:
-        if img_tuple[0] == Path(img).name:
-            logger.debug(Path(img).name)
-            pred, prob = predict(img_tuple[1])
-            scores.append(prob)
-            break
        bboxes.append([n, task['bbox']])
    out_cropped = f'tmp/cropped/{Path(img_name).stem}_{bboxes[0][0]}.jpg'
    save_crop(img, bboxes[0][1], False, out_cropped)

    pred, prob = predict(out_cropped)

    x, y, width, height = [x * 100 for x in task['bbox']]

    scores.append(prob)
    results.append({
        'from_name': 'label',
        'to_name': 'image',
@@ -90,52 +164,41 @@ def main(task_id):
})

post_ = {
"model_version": "Megadetector",
"result": results,
"score": np.mean(scores),
"cluster": 0,
"neighbors": {},
"mislabeling": 0,
"task": task_id
'model_version': 'picam-detector_1647175692',
'result': results,
'score': np.mean(scores),
'cluster': 0,
'neighbors': {},
'mislabeling': 0,
'task': task_id
}

url = "https://ls.aibird.me/api/predictions/"
resp = requests.post(url, headers=headers, data=json.dumps(post_))
logger.debug(resp.json())
return resp


if __name__ == '__main__':
-    logger.add('logs/apply_predictions.log')

-    class_names = 'class_names.npy'
-    pretrained_weights = 'weights/1647175692.h5'

-    images = glob('dataset_cropped/**/*.jpg', recursive=True)
-    images = [(Path(x).name, x) for x in images]

-    # if len(sys.argv) == 1:
-    #     raise Exception('You need to provide a path to the output data file!')
-    # if not Path(sys.argv[1]).exists():
-    #     raise FileNotFoundError('The path you entered does not exist!')
-    # md_data_file = sys.argv[1]

-    # with open(md_data_file) as j:
-    #     md_data = json.load(j)
    logger.add('apply_predictions.log')
    signal.signal(signal.SIGINT, keyboard_interrupt_handler)

    md_data = get_mongodb_data()
    headers = make_headers()
    mkdirs()

-    data = get_all_tasks(project_id=6)

-    tasks_ids = [x['id'] for x in data]

-    i = 0
-    for cur_task in tqdm(tasks_ids):
-        try:
-            out = main(cur_task)
-            if out:
-                i += 1
-        except KeyboardInterrupt:
-            sys.exit('Interrupted by the user...')

-    logger.info(f'Total number of predictions applied: {i}')
    class_names = 'class_names.npy'
    pretrained_weights = 'weights/03_14_2022__1647175692.h5'
    project_id = 8

    project_tasks = get_all_tasks(headers, project_id)
    tasks_ids = [t_['id'] for t_ in project_tasks]

    logger.debug('Starting prediction...')
    try:
        for task_id in tqdm(tasks_ids):
            main(task_id)
    except Exception as e:
        logger.exception(e)
    finally:
        shutil.rmtree('tmp')
        sys.exit(0)
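
For quick verification after the updated script runs, the sketch below fetches one task back from the same Label Studio instance and prints any predictions attached to it. This is not part of the commit: the Authorization header format mirrors Label Studio's standard Token scheme (the body of make_headers() is not shown in this diff), and the task ID is a placeholder.

import json
import os

import requests
from dotenv import load_dotenv

load_dotenv()
# Reuses the same .env TOKEN the script relies on; the header format is an assumption.
headers = {'Authorization': f"Token {os.environ['TOKEN']}",
           'Content-Type': 'application/json'}

task_id = 12345  # placeholder task ID
resp = requests.get(f'https://ls.aibird.me/api/tasks/{task_id}', headers=headers)
print(json.dumps(resp.json().get('predictions', []), indent=2))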