Zero-shot image classification model for Russian language
RuCLIP (Russian Contrastive Language–Image Pretraining) is a multimodal model for obtaining images and text similarities and rearranging captions and pictures. RuCLIP builds on a large body of work on zero-shot transfer, computer vision, natural language processing and multimodal learning. This repo has the prototypes model of OpenAI CLIP's Russian version following this paper.
- ruclip-vit-base-patch32-224 🤗
- ruclip-vit-base-patch16-224 🤗
- ruclip-vit-large-patch14-224 🤗
- ruclip-vit-base-patch32-384 🤗
- ruclip-vit-large-patch14-336 ☁️SberCloud only
- ruclip-vit-base-patch16-384 ☁️SberCloud only ️
pip install ruclip==0.0.1rc7
import ruclip
device = 'cuda'
clip, processor = ruclip.load('ruclip-vit-large-patch14-224', device=device)
import torch
import base64
import requests
from PIL import Image
from io import BytesIO
# prepare images
bs4_urls = requests.get('https://raw.githubusercontent.com/sberbank-ai/ru-dolph/master/pics/pipelines/cats_vs_dogs_bs4.json').json()
images = [Image.open(BytesIO(base64.b64decode(bs4_url))) for bs4_url in bs4_urls]
# prepare classes
classes = ['кошка', 'собака']
templates = ['{}', 'это {}', 'на картинке {}', 'это {}, домашнее животное']
# predict
predictor = ruclip.Predictor(clip, processor, device, bs=8, templates=templates)
with torch.no_grad():
text_latents = predictor.get_text_latents(classes)
pred_labels = predictor.run(images, text_latents)
# show results
f, ax = plt.subplots(2,4, figsize=(12,6))
for i, (pil_img, pred_label) in enumerate(zip(images, pred_labels)):
ax[i//4, i%4].imshow(pil_img)
ax[i//4, i%4].set_title(classes[pred_label])
train = CIFAR100(root, download=True, train=True)
test = CIFAR100(root, download=True, train=False)
with torch.no_grad():
X_train = predictor.get_image_latents((pil_img for pil_img, _ in train)).cpu().numpy()
X_test = predictor.get_image_latents((pil_img for pil_img, _ in test)).cpu().numpy()
y_train, y_test = np.array(train.targets), np.array(test.targets)
clf = LogisticRegression(solver='lbfgs', penalty='l2', max_iter=1000, verbose=1)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
accuracy = np.mean((y_test == y_pred).astype(np.float)) * 100.
print(f"Accuracy = {accuracy:.3f}")
>>> Accuracy = 75.680
We have evaluated the performance zero-shot image classification on the following datasets:
Dataset | ruclip-vit-base-patch32-224 | ruclip-vit-base-patch16-224 | ruclip-vit-large-patch14-224 | ruclip-vit-base-patch32-384 | ruclip-vit-large-patch14-336 ☁️SberCloud only | ruclip-vit-base-patch16-384 ☁️SberCloud only |
---|---|---|---|---|---|---|
Food101, acc | 0.505 | 0.552 | 0.597 | 0.642 | 0.712 💥 | 0.689 |
CIFAR10, acc | 0.818 | 0.810 | 0.878 | 0.862 | 0.906 💥 | 0.845 |
CIFAR100, acc | 0.504 | 0.496 | 0.511 | 0.529 | 0.591 💥 | 0.569 |
Birdsnap, acc | 0.115 | 0.117 | 0.172 | 0.161 | 0.213 💥 | 0.195 |
SUN397, acc | 0.452 | 0.462 | 0.484 | 0.510 | 0.523 💥 | 0.521 |
Stanford Cars, acc | 0.433 | 0.487 | 0.559 | 0.572 | 0.659 💥 | 0.626 |
DTD, acc | 0.380 | 0.401 | 0.370 | 0.390 | 0.408 | 0.421 💥 |
MNIST, acc | 0.447 | 0.464 | 0.337 | 0.404 | 0.242 | 0.478 💥 |
STL10, acc | 0.932 | 0.932 | 0.934 | 0.946 | 0.956 | 0.964 💥 |
PCam, acc | 0.501 | 0.505 | 0.520 | 0.506 | 0.554 💥 | 0.501 |
CLEVR, acc | 0.148 | 0.128 | 0.152 | 0.188 💥 | 0.142 | 0.132 |
Rendered SST2, acc | 0.489 | 0.527 | 0.529 | 0.508 | 0.539 💥 | 0.525 |
ImageNet, acc | 0.375 | 0.401 | 0.426 | 0.451 | 0.488 💥 | 0.482 |
FGVC Aircraft, mean-per-class | 0.033 | 0.043 | 0.046 | 0.053 | 0.075 💥 | 0.046 |
Oxford Pets, mean-per-class | 0.560 | 0.595 | 0.604 | 0.587 | 0.546 | 0.635 💥 |
Caltech101, mean-per-class | 0.786 | 0.775 | 0.777 | 0.834 | 0.835 💥 | 0.835 💥 |
Flowers102, mean-per-class | 0.401 | 0.388 | 0.455 | 0.449 | 0.517 💥 | 0.452 |
Hateful Memes, roc-auc | 0.564 💥 | 0.516 | 0.530 | 0.537 | 0.519 | 0.543 |
And for linear-prob evaluation:
Dataset | ruclip-vit-base-patch32-224 | ruclip-vit-base-patch16-224 | ruclip-vit-large-patch14-224 | ruclip-vit-base-patch32-384 | ruclip-vit-large-patch14-336 | ruclip-vit-base-patch16-384 |
---|---|---|---|---|---|---|
Food101 | 0.765 | 0.827 | 0.840 | 0.851 | 0.896 💥 | 0.890 |
CIFAR10 | 0.917 | 0.922 | 0.927 | 0.934 | 0.943 💥 | 0.942 |
CIFAR100 | 0.716 | 0.739 | 0.734 | 0.745 | 0.770 | 0.773 💥 |
Birdsnap | 0.347 | 0.503 | 0.567 | 0.434 | 0.609 | 0.612 💥 |
SUN397 | 0.683 | 0.721 | 0.731 | 0.721 | 0.759 💥 | 0.758 |
Stanford Cars | 0.697 | 0.776 | 0.797 | 0.766 | 0.831 | 0.840 💥 |
DTD | 0.690 | 0.734 | 0.711 | 0.703 | 0.731 | 0.749 💥 |
MNIST | 0.963 | 0.974 💥 | 0.949 | 0.965 | 0.949 | 0.971 |
STL10 | 0.957 | 0.962 | 0.973 | 0.968 | 0.981 💥 | 0.974 |
PCam | 0.827 | 0.823 | 0.791 | 0.835 | 0.807 | 0.846 💥 |
CLEVR | 0.356 | 0.360 | 0.358 | 0.308 | 0.318 | 0.378 💥 |
Rendered SST2 | 0.603 | 0.655 | 0.651 | 0.651 | 0.637 | 0.661 💥 |
FGVC Aircraft | 0.254 | 0.312 | 0.290 | 0.283 | 0.341 | 0.362 💥 |
Oxford Pets | 0.774 | 0.820 | 0.819 | 0.730 | 0.753 | 0.856 💥 |
Caltech101 | 0.904 | 0.917 | 0.914 | 0.922 | 0.937 💥 | 0.932 |
HatefulMemes | 0.545 | 0.568 | 0.563 | 0.581 | 0.585 💥 | 0.578 |
Also, we have created speed comparison based on CIFAR100 dataset using Nvidia-V100 for evaluation:
ruclip-vit-base-patch32-224 | ruclip-vit-base-patch16-224 | ruclip-vit-large-patch14-224 | ruclip-vit-base-patch32-384 | ruclip-vit-large-patch14-336 | ruclip-vit-base-patch16-384 | |
---|---|---|---|---|---|---|
iter/sec | 308.84 💥 | 155.35 | 49.95 | 147.26 | 22.11 | 61.79 |