Skip to content

Commit

Permalink
Add AI background removal
Browse files Browse the repository at this point in the history
  • Loading branch information
pierotofy committed Sep 19, 2022
1 parent deb5327 commit cd72000
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 0 deletions.
90 changes: 90 additions & 0 deletions opendm/bgfilter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@

import time
import numpy as np
import cv2
import os
import onnxruntime as ort
from opendm import log
from threading import Lock

mutex = Lock()

# Implementation based on https://github.com/danielgatis/rembg by Daniel Gatis

# Use GPU if it is available, otherwise CPU
provider = "CUDAExecutionProvider" if "CUDAExecutionProvider" in ort.get_available_providers() else "CPUExecutionProvider"

class BgFilter():
    """Generate binary background masks for images with an ONNX model.

    The model file is loaded once into an ONNX Runtime session (using the
    module-level ``provider``); ``run_img`` then reads an image, runs the
    model on a resized copy and writes a ``<name>_mask.png`` file.
    """

    def __init__(self, model):
        """Remember the model location and create the inference session.

        :param model: model file passed to ``onnxruntime.InferenceSession``
        """
        self.model = model

        log.ODM_INFO(' ?> Using provider %s' % provider)
        self.load_model()

    def load_model(self):
        """(Re)create ``self.session`` from ``self.model``."""
        log.ODM_INFO(' -> Loading the model')

        self.session = ort.InferenceSession(self.model, providers=[provider])

    @staticmethod
    def _scale_by_peak(values):
        """Return ``values`` divided by its maximum as a float64 array.

        When the maximum is 0 (e.g. a completely black image) the division
        would be 0/0 and produce NaNs, so the values are returned unscaled.
        """
        values = np.asarray(values)
        peak = np.max(values)
        if peak == 0:
            return values.astype(np.float64)
        return values / peak

    @staticmethod
    def _min_max_scale(values):
        """Linearly rescale ``values`` so that min maps to 0 and max to 1.

        A constant input has no range to stretch; it yields all zeros
        instead of the NaNs a division by zero would give.
        """
        lowest = np.min(values)
        span = np.max(values) - lowest
        if span > 0:
            return (values - lowest) / span
        return np.zeros_like(values)

    def normalize(self, img, mean, std, size):
        """Build the model input feed for ``img``.

        The image is resized to ``size``, scaled by its peak value, then the
        first three channels are standardized with ``mean`` and ``std``.

        :param img: image array with at least three channels (H x W x C)
        :param mean: per-channel means (first three entries are used)
        :param std: per-channel standard deviations (first three are used)
        :param size: ``(width, height)`` passed to ``cv2.resize``
        :return: dict mapping the session's first input name to a float32
            array of shape ``(1, 3, height, width)``
        """
        resized = cv2.resize(img, size, interpolation=cv2.INTER_AREA)
        scaled = self._scale_by_peak(np.array(resized))

        # Broadcasting over the channel axis is equivalent to standardizing
        # each of the three channels separately.
        channel_mean = np.asarray(mean[:3], dtype=np.float64)
        channel_std = np.asarray(std[:3], dtype=np.float64)
        standardized = (scaled[:, :, :3] - channel_mean) / channel_std

        # Channels-last -> channels-first, then add the batch dimension.
        chw = standardized.transpose((2, 0, 1))
        input_name = self.session.get_inputs()[0].name

        return {input_name: np.expand_dims(chw, 0).astype(np.float32)}

    def get_mask(self, img):
        """Run the model on ``img`` and return a binary mask of the same size.

        :param img: image array; its first two dimensions are height and width
        :return: uint8 array of shape ``(height, width)`` containing only
            0 and 255
        """
        height, width = img.shape[:2]

        # Inference (including input preparation) is serialized across
        # threads through the module-level mutex.
        with mutex:
            ort_outs = self.session.run(
                None,
                self.normalize(
                    img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)  # <-- image size
                ),
            )

        # First output, first channel of every batch entry.
        pred = ort_outs[0][:, 0, :, :]
        pred = np.squeeze(self._min_max_scale(pred))
        pred = (pred * 255).astype("uint8")

        resized = cv2.resize(pred, (width, height), interpolation=cv2.INTER_LANCZOS4)

        # Binarize: everything above mid-gray becomes 255, the rest 0.
        return np.where(resized > 127, 255, 0).astype(np.uint8)

    def run_img(self, img_path, dest):
        """Compute the mask for the image at ``img_path`` and save it in ``dest``.

        The mask is written as ``<image name without extension>_mask.png``.

        :param img_path: path of the image to read
        :param dest: directory in which the mask file is written
        :return: path of the written mask, or ``None`` when the image cannot
            be read or the mask cannot be written
        """
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if img is None:
            return None

        # OpenCV loads BGR; get_mask is fed RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        mask = self.get_mask(img)

        stem, _ = os.path.splitext(os.path.basename(img_path))
        mask_name = os.path.join(dest, stem + '_mask.png')

        # imwrite reports failure through its return value; do not hand back
        # a path to a file that was never written.
        if not cv2.imwrite(mask_name, mask):
            return None

        return mask_name
6 changes: 6 additions & 0 deletions opendm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,12 @@ def config(argv=None, parser=None):
nargs=0,
default=False,
help='Automatically compute image masks using AI to remove the sky. Experimental. Default: %(default)s')

parser.add_argument('--bg-removal',
action=StoreTrue,
nargs=0,
default=False,
help='Automatically compute image masks using AI to remove the background. Experimental. Default: %(default)s')

parser.add_argument('--use-3dmesh',
action=StoreTrue,
Expand Down
42 changes: 42 additions & 0 deletions stages/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from opendm import boundary
from opendm import ai
from opendm.skyremoval.skyfilter import SkyFilter
from opendm.bgfilter import BgFilter
from opendm.concurrency import parallel_map

def save_images_database(photos, database_file):
Expand Down Expand Up @@ -191,6 +192,47 @@ def parallel_sky_filter(item):

# End sky removal

# Automatic background removal
if args.bg_removal:
# For each image that:
# - Doesn't already have a mask, AND
# - Has no spaces in its filename (OpenSfM requirement)

# Generate list of background images
bg_images = []
for p in photos:
if p.mask is None and (not " " in p.filename):
bg_images.append({'file': os.path.join(images_dir, p.filename), 'p': p})

if len(bg_images) > 0:
log.ODM_INFO("Automatically generating background masks for %s images" % len(bg_images))
model = ai.get_model("bgremoval", "https://github.com/OpenDroneMap/ODM/releases/download/v2.9.0/u2net.zip", "v2.9.0")
if model is not None:
bg = BgFilter(model=model)

def parallel_bg_filter(item):
    """Generate the background mask for one image and attach it to its photo.

    :param item: dict with ``'file'`` (path of the image) and ``'p'`` (the
        photo object that receives the mask via ``set_mask``)

    Failures are logged as warnings and never raised, so one bad image does
    not abort the whole batch.
    """
    try:
        mask_file = bg.run_img(item['file'], images_dir)

        # Check and set
        if mask_file is not None and os.path.isfile(mask_file):
            item['p'].set_mask(os.path.basename(mask_file))
            log.ODM_INFO("Wrote %s" % os.path.basename(mask_file))
        else:
            # Report the image from the work item; there is no `img`
            # variable in this function.
            log.ODM_WARNING("Cannot generate mask for %s" % item['file'])
    except Exception as e:
        log.ODM_WARNING("Cannot generate mask for %s: %s" % (item['file'], str(e)))

parallel_map(parallel_bg_filter, bg_images, max_workers=args.max_concurrency)

log.ODM_INFO("Background masks generation completed!")
else:
log.ODM_WARNING("Cannot load AI model (you might need to be connected to the internet?)")
else:
log.ODM_INFO("No background masks will be generated (masks already provided)")

# End bg removal

# Save image database for faster restart
save_images_database(photos, images_database_file)
else:
Expand Down

0 comments on commit cd72000

Please sign in to comment.