Showing 5 changed files with 324 additions and 3 deletions.
@@ -158,3 +158,6 @@ dataset_cropped/
 *.tar
 project-*.json
 not_downloaded.json
+logs/
+saved_models/
+class_names.npy
@@ -0,0 +1,115 @@
# NOTE: only applies bounding boxes at the moment!

import json
import os
import sys
from pathlib import Path

import numpy as np
import requests
from dotenv import load_dotenv
from requests.structures import CaseInsensitiveDict
from loguru import logger
from tqdm import tqdm


def make_headers():
    """Build Label Studio API headers from the TOKEN environment variable."""
    load_dotenv()
    TOKEN = os.environ['TOKEN']
    headers = CaseInsensitiveDict()
    headers["Content-type"] = "application/json"
    headers["Authorization"] = f"Token {TOKEN}"
    return headers


def get_all_tasks(project_id):
    """Fetch all tasks in the project and cache them in tasks_latest.json."""
    logger.debug('Getting tasks data... This might take a few minutes...')
    url = f"https://ls.aibird.me/api/projects/{project_id}/tasks?page_size=10000"
    headers = make_headers()
    resp = requests.get(url,
                        headers=headers,
                        data=json.dumps({'project': project_id}))
    with open('tasks_latest.json', 'w') as j:
        json.dump(resp.json(), j)
    return resp.json()


def find_image(img_name):
    """Look up an image entry by file name in the global MegaDetector output."""
    for im in md_data['images']:
        if Path(im['file']).name == img_name:
            return im


def main(task_id):
    headers = make_headers()
    url = f"https://ls.aibird.me/api/tasks/{task_id}"
    resp = requests.get(url, headers=headers)
    task_ = resp.json()
    if not task_['predictions']:
        img = task_['data']['image']
    else:
        # The task already has predictions; skip it.
        return
    md_preds = find_image(Path(img).name)
    if not md_preds:
        # No matching entry in the MegaDetector output; skip the task.
        return

    results = []
    scores = []
    for item in md_preds['detections']:
        if item['category'] != '1':
            continue
        # MegaDetector boxes are relative [x, y, width, height];
        # Label Studio expects percentages.
        x, y, width, height = [v * 100 for v in item['bbox']]
        scores.append(item['conf'])
        results.append({
            'from_name': 'label',
            'to_name': 'image',
            'type': 'rectanglelabels',
            'value': {
                'rectanglelabels': ['object'],
                'x': x,
                'y': y,
                'width': width,
                'height': height
            },
            'score': item['conf']
        })

    if not results:
        # No category-'1' detections for this image; nothing to post.
        return

    post_ = {
        "model_version": "Megadetector",
        "result": results,
        "score": np.mean(scores),
        "cluster": 0,
        "neighbors": {},
        "mislabeling": 0,
        "task": task_id
    }

    url = "https://ls.aibird.me/api/predictions/"
    resp = requests.post(url, headers=headers, data=json.dumps(post_))
    logger.debug(resp.json())
    return resp


if __name__ == '__main__':
    logger.add('apply_predictions.log')
    if len(sys.argv) == 1:
        raise Exception('You need to provide a path to the output data file!')
    if not Path(sys.argv[1]).exists():
        raise FileNotFoundError('The path you entered does not exist!')
    md_data_file = sys.argv[1]
    with open(md_data_file) as j:
        md_data = json.load(j)

    data = get_all_tasks(project_id=1)
    # Alternatively, reuse the cached copy written by get_all_tasks():
    # with open('tasks_latest.json') as j:
    #     data = json.load(j)
    tasks_ids = [x['id'] for x in data]

    i = 0
    for cur_task in tqdm(tasks_ids):
        try:
            out = main(cur_task)
            if out:
                i += 1
        except KeyboardInterrupt:
            sys.exit('Interrupted by the user...')
    logger.info(f'Total number of predictions applied: {i}')
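
For context, the script above assumes the MegaDetector output file passed on the command line has roughly the following shape. The field names ('images', 'file', 'detections', 'category', 'conf', 'bbox') are taken from the code; the path and values below are purely illustrative:

md_data_example = {
    'images': [
        {
            'file': 'local-files/picam/2022-03-01/IMG_0001.jpg',  # hypothetical path
            'detections': [
                {
                    'category': '1',                   # only category '1' is applied
                    'conf': 0.92,                      # detection confidence
                    'bbox': [0.10, 0.20, 0.30, 0.40],  # relative [x, y, width, height]
                },
            ],
        },
    ],
}

A run then looks like: python <script> megadetector_output.json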
@@ -0,0 +1,40 @@
import hashlib
import json
import shutil
from glob import glob
from pathlib import Path

from loguru import logger
from tqdm import tqdm


def gen_hash(file_path):
    """Return the MD5 hex digest of a file (used for deduplication, not security)."""
    with open(file_path, 'rb') as f:
        content = f.read()
    results = hashlib.md5(content).hexdigest()
    return results


def main(root_path='local-files/picam'):
    files = glob(f'{root_path}/**/*.jpg', recursive=True)
    with open('hashes.json') as j:
        stored_hashes = json.load(j)
    # Keep a backup copy in case the run is interrupted mid-write.
    shutil.copy2('hashes.json', '_hashes.json')

    hashes = []
    i = 0
    for file in tqdm(files):
        _hash = {Path(file).name: gen_hash(file)}
        if _hash not in stored_hashes:
            i += 1
            hashes.append(_hash)
    logger.debug(f'Hashed {i} new images')

    with open('hashes.json', 'w') as j:
        json.dump(stored_hashes + hashes, j)

    # The write succeeded, so the backup is no longer needed.
    Path('_hashes.json').unlink()


if __name__ == '__main__':
    main()
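
Note that main() assumes hashes.json already exists; a first run can seed it with an empty list. Inferred from the code above, each stored entry is a one-key dict mapping a file name to its MD5 digest:

import json

# Seed hashes.json before the first run (the script expects it to exist).
# Entries later look like: {"IMG_0001.jpg": "5d41402abc4b2a76b9719d911017c592"}
with open('hashes.json', 'w') as j:
    json.dump([], j)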
@@ -0,0 +1,160 @@
import os
import datetime
import time
import warnings

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
warnings.simplefilter(action='ignore', category=FutureWarning)

import PIL.Image as Image
import numpy as np
import matplotlib

matplotlib.use('agg')
import matplotlib.pylab as plt
import tensorflow as tf
import tensorflow_hub as hub
from loguru import logger


def load_pretrained(pretrained_model='mobilenet_v2'):
    """Return the TF-Hub URL of the selected pretrained classifier."""
    if pretrained_model == 'mobilenet_v2':
        classifier_model = 'https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4'
    elif pretrained_model == 'inception_v3':
        classifier_model = 'https://tfhub.dev/google/imagenet/inception_v3/classification/5'
    else:
        raise ValueError(f'Unknown pretrained model: {pretrained_model}')
    return classifier_model


def prepare_dataset(classifier_model, data_root='dataset_cropped'):
    IMAGE_SHAPE = (224, 224)
    classifier = tf.keras.Sequential(
        [hub.KerasLayer(classifier_model, input_shape=IMAGE_SHAPE + (3, ))])

    batch_size = 32
    img_height = 224
    img_width = 224

    # 80/20 train/validation split from the same directory tree.
    train_ds = tf.keras.utils.image_dataset_from_directory(
        str(data_root),
        validation_split=0.2,
        subset='training',
        seed=123,
        image_size=(img_height, img_width),
        batch_size=batch_size)
    val_ds = tf.keras.utils.image_dataset_from_directory(
        str(data_root),
        validation_split=0.2,
        subset='validation',
        seed=123,
        image_size=(img_height, img_width),
        batch_size=batch_size)

    class_names = np.array(train_ds.class_names)
    logger.debug(class_names)

    # Scale pixel values from [0, 255] to [0, 1].
    normalization_layer = tf.keras.layers.Rescaling(1. / 255)
    train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
    val_ds = val_ds.map(lambda x, y: (normalization_layer(x), y))

    AUTOTUNE = tf.data.AUTOTUNE
    train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
    val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
    return train_ds, val_ds, class_names


def feature_extractor(train_ds, loaded_feature_extractor_model='mobilenet_v2'):
    # Grab one batch to verify shapes and reuse it for sanity checks below.
    for image_batch, labels_batch in train_ds:
        logger.debug(image_batch.shape)
        logger.debug(labels_batch.shape)
        break
    if loaded_feature_extractor_model == 'mobilenet_v2':
        feature_extractor_model = 'https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4'
    elif loaded_feature_extractor_model == 'inception_v3':
        feature_extractor_model = 'https://tfhub.dev/google/tf2-preview/inception_v3/feature_vector/4'

    # Freeze the feature extractor; only the classification head is trained.
    feature_extractor_layer = hub.KerasLayer(feature_extractor_model,
                                             input_shape=(224, 224, 3),
                                             trainable=False)

    feature_batch = feature_extractor_layer(image_batch)
    logger.debug(f'feature_batch.shape: {feature_batch.shape}')

    num_classes = len(class_names)  # relies on the module-level class_names
    logger.debug(f'num_classes: {num_classes}')
    return feature_extractor_layer, image_batch


def build_model(feature_extractor_layer, image_batch):
    # Frozen feature extractor + trainable dense head (one logit per class);
    # num_classes is read from module level.
    model = tf.keras.Sequential(
        [feature_extractor_layer,
         tf.keras.layers.Dense(num_classes)])
    model.summary()
    predictions = model(image_batch)
    logger.debug(predictions.shape)

    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['acc'])
    return model


def tb_callback():
    """Create a TensorBoard callback logging under logs/fit/<timestamp>."""
    log_dir = 'logs/fit/' + datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir,
                                                          histogram_freq=1)
    return tensorboard_callback


def train_model(NUM_EPOCHS=100):
    # Uses the module-level model, train_ds and val_ds.
    history = model.fit(train_ds,
                        validation_data=val_ds,
                        epochs=NUM_EPOCHS,
                        callbacks=[tb_callback()])
    return model, history


def predict(model, image_batch, plot_preds=False):
    predicted_batch = model.predict(image_batch)
    predicted_id = tf.math.argmax(predicted_batch, axis=-1)
    predicted_label_batch = class_names[predicted_id]
    logger.debug(predicted_label_batch)

    if plot_preds:
        plt.figure(figsize=(10, 9))
        plt.subplots_adjust(hspace=0.5)
        for n in range(30):
            plt.subplot(6, 5, n + 1)
            plt.imshow(image_batch[n])
            plt.title(predicted_label_batch[n].title())
            plt.axis('off')
        _ = plt.suptitle('Model predictions')


def export_model(model, class_names):
    # Compute the timestamp once so the model and the class names
    # end up in the same directory.
    export_path = f'saved_models/{int(time.time())}'
    model.save(export_path)
    np.save(f'{export_path}/class_names.npy', class_names)
    logger.debug(export_path)


def predict_from_exported(export_path,
                          class_names,
                          image_batch,
                          plot_preds=False):
    reloaded = tf.keras.models.load_model(export_path)
    reloaded_result_batch = reloaded.predict(image_batch)
    reloaded_predicted_id = tf.math.argmax(reloaded_result_batch, axis=-1)
    reloaded_predicted_label_batch = class_names[reloaded_predicted_id]
    logger.debug(f'Prediction: {reloaded_predicted_label_batch}')

    if plot_preds:
        plt.figure(figsize=(10, 9))
        plt.subplots_adjust(hspace=0.5)
        for n in range(30):
            plt.subplot(6, 5, n + 1)
            plt.imshow(image_batch[n])
            plt.title(reloaded_predicted_label_batch[n].title())
            plt.axis('off')
        _ = plt.suptitle('Model predictions')
    return reloaded_predicted_label_batch
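
The file defines functions but no entry point, and several of them read module-level names (class_names, num_classes, model, train_ds, val_ds). A minimal driver sketch, not part of the commit, showing how the pieces chain together when appended to the same module (the assignments below become the globals the functions reference):

if __name__ == '__main__':
    # Transfer learning: frozen TF-Hub backbone + small trainable head.
    classifier_model = load_pretrained('mobilenet_v2')
    train_ds, val_ds, class_names = prepare_dataset(classifier_model)
    num_classes = len(class_names)

    feature_extractor_layer, image_batch = feature_extractor(train_ds)
    model = build_model(feature_extractor_layer, image_batch)
    model, history = train_model(NUM_EPOCHS=10)  # epoch count is illustrative

    predict(model, image_batch)
    export_model(model, class_names)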
@@ -1,8 +1,11 @@
 requests>=2.27.1
-loguru>=0.5.3
+loguru>=0.6.0
 tqdm>=4.62.3
+numpy>=1.22.3
 matplotlib>=3.5.1
-numpy>=1.21.4
-pandas>=1.3.4
 tensorflow>=2.7.0
 tensorflow_hub>=0.12.0
+python-dotenv>=0.19.2
+opencv-python>=4.5.5.64
+pillow>=9.0.1
+pandas>=1.4.1