The official JAX models from "LiT: Zero-Shot Transfer with Locked-image text Tuning", converted to PyTorch.
The conversion path is JAX -> TensorFlow -> ONNX -> PyTorch.
- The image encoder is loaded into PyTorch and supports gradients.
- The text encoder is not loaded into PyTorch; it runs via ONNX on the CPU.
poetry add pytorch-zero-lit
or
pip install pytorch-zero-lit
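import torchvision.transforms.functional as TF
from PIL import Image

from lit import LiT

model = LiT()

# Load one RGB image, resize it to 224x224, convert it to a (3, 224, 224)
# float tensor with values in [0, 1], and add a batch dimension -> (1, 3, 224, 224)
images = TF.to_tensor(
    Image.open("cat.png").convert("RGB").resize((224, 224))
)[None]

# Candidate labels written as natural-language prompts
texts = [
    "a photo of a cat",
    "a photo of a dog",
    "a photo of a bird",
    "a photo of a fish",
]

image_encodings = model.encode_images(images)  # PyTorch image encoder (supports gradients)
text_encodings = model.encode_texts(texts)  # ONNX text encoder (runs on CPU)

# Similarity scores between the image encodings and the text encodings
cosine_similarity = model.cosine_similarity(image_encodings, text_encodings)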
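The similarity scores can be read as a zero-shot classifier: the prompt whose text encoding lies closest to the image encoding is the predicted label. Below is a minimal NumPy sketch of that idea, assuming the encodings are 2-D arrays of shape (batch, dim). It does not call the `lit` package, the helpers `cosine_similarity_matrix` and `zero_shot_predict` are hypothetical, the 4-dimensional encodings are made-up values rather than real LiT outputs, and it is not necessarily how `model.cosine_similarity` is implemented.

```python
from typing import List, Sequence

import numpy as np


def cosine_similarity_matrix(image_encodings: np.ndarray, text_encodings: np.ndarray) -> np.ndarray:
    """Return an (n_images, n_texts) matrix of cosine similarities (hypothetical helper)."""
    # L2-normalize each row so that a dot product equals the cosine of the angle
    img = image_encodings / np.linalg.norm(image_encodings, axis=-1, keepdims=True)
    txt = text_encodings / np.linalg.norm(text_encodings, axis=-1, keepdims=True)
    return img @ txt.T


def zero_shot_predict(similarity: np.ndarray, texts: Sequence[str]) -> List[str]:
    """Pick the highest-scoring prompt for each image (hypothetical helper)."""
    return [texts[i] for i in similarity.argmax(axis=-1)]


# Made-up 4-dimensional encodings, for illustration only
texts = ["a photo of a cat", "a photo of a dog", "a photo of a bird", "a photo of a fish"]
image_encodings = np.array([[0.9, 0.1, 0.0, 0.1]])
text_encodings = np.eye(4)  # one unit vector per prompt

similarity = cosine_similarity_matrix(image_encodings, text_encodings)
print(similarity.shape)  # (1, 4)
print(zero_shot_predict(similarity, texts))  # ['a photo of a cat']
```

To apply this to real outputs, the encodings would first need converting to NumPy arrays (or the same normalize-and-multiply steps done in PyTorch); the exact types and shapes returned by `encode_images` and `encode_texts` are an assumption to verify.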