Skip to content

Commit

Permalink
Merge pull request #5 from chenxwh/main
Browse files Browse the repository at this point in the history
Add Replicate demo and API
  • Loading branch information
mv-lab authored Feb 1, 2024
2 parents bc395f6 + ffc67b8 commit 8a0684e
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 2 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
[![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2401.16468)
<a href="https://colab.research.google.com/drive/1OrTvS-i6uLM2Y8kIkq8ZZRwEQxQFchfq?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>
[![Hugging Face](https://img.shields.io/badge/Demo-%F0%9F%A4%97%20Hugging%20Face-blue)](https://huggingface.co/spaces/marcosv/InstructIR)
[![Replicate](https://img.shields.io/badge/Demo-%F0%9F%9A%80%20Replicate-blue)](https://replicate.com/mv-lab)
[![Replicate](https://img.shields.io/badge/Demo-%F0%9F%9A%80%20Replicate-blue)](https://replicate.com/mv-lab/instructir)
[![Paper page](https://huggingface.co/datasets/huggingface/badges/resolve/main/paper-page-sm.svg)](https://huggingface.co/papers/2401.16468)
[![Model on HF](https://huggingface.co/datasets/huggingface/badges/resolve/main/model-on-hf-sm.svg)](https://huggingface.co/marcosv/InstructIR)
[![Open in Spaces](https://huggingface.co/datasets/huggingface/badges/resolve/main/open-in-hf-spaces-md.svg)](https://huggingface.co/spaces/marcosv/InstructIR)
Expand All @@ -31,7 +31,7 @@ Image restoration is a fundamental problem that involves recovering a high-quali

### TODO / News 🔥

- [ ] Replicate Demo
- [x] Replicate Demo
- [ ] Upload all test results for comparisons (ETA 1st Feb)
- [x] Upload models to HF 🤗 [(download the models here)](https://huggingface.co/marcosv/InstructIR)
- [x] 🤗 [Hugging Face Demo](https://huggingface.co/spaces/marcosv/InstructIR) try it now
Expand Down
19 changes: 19 additions & 0 deletions cog.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Configuration for Cog ⚙️
# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md

# Build-time environment for the Cog image.
build:
# NOTE(review): a GPU is requested here, but predict.py creates its device
# with torch.device("cpu") — confirm which one is intended.
gpu: true
# System-level packages installed in the image
# (presumably the shared libraries opencv-python needs — TODO confirm).
system_packages:
- "libgl1-mesa-glx"
- "libglib2.0-0"
python_version: "3.11"
# Python dependencies, pinned to exact versions.
python_packages:
- torch==2.0.1
- transformers==4.37.2
- PyYAML==6.0.1
- Pillow==10.2.0
- sentence-transformers==2.3.0
- opencv-python==4.9.0.80
- matplotlib==3.8.2
- imageio==2.33.1
# Entry point, written as "<file>:<class>"; refers to the Predictor class in predict.py.
predict: "predict.py:Predictor"
113 changes: 113 additions & 0 deletions predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md

import os
import numpy as np
import yaml
import torch
from cog import BasePredictor, Input, Path

from utils import *
from models import instructir
from text.models import LanguageModel, LMHead

os.environ["TOKENIZERS_PARALLELISM"] = "false"


class Predictor(BasePredictor):
    """Cog predictor that restores an image following a text instruction.

    ``setup`` builds the image model and the language-model head and loads
    their checkpoints once; ``predict`` then runs them on a single
    image/prompt pair and returns the path of the restored image.
    """

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""

        lm_head_ckpt = "models/lm_instructir-7d.pt"
        image_model_ckpt = "models/im_instructir-7d.pt"
        # NOTE(review): inference is pinned to the CPU even though cog.yaml
        # requests a GPU. process_img() does not move its tensors to a device,
        # so switching this to CUDA requires changing it as well.
        device = torch.device("cpu")

        # NOTE(review): the config is named "eval5d" while the checkpoints are
        # named "7d" — confirm they belong together. strict=True below will at
        # least reject mismatching parameter names/shapes.
        with open("configs/eval5d.yml", "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)
        cfg = dict2namespace(config)

        torch.backends.cudnn.deterministic = True

        # Image restoration model.
        self.model = instructir.create_model(
            input_channels=cfg.model.in_ch,
            width=cfg.model.width,
            enc_blks=cfg.model.enc_blks,
            middle_blk_num=cfg.model.middle_blk_num,
            dec_blks=cfg.model.dec_blks,
            txtdim=cfg.model.textdim,
        )
        self.model = self.model.to(device)
        print("IMAGE MODEL CKPT:", image_model_ckpt)
        self.model.load_state_dict(
            torch.load(image_model_ckpt, map_location="cpu"), strict=True
        )
        # Inference only: make sure train-time layer behaviour is disabled.
        self.model.eval()

        # Language model plus the head that maps its embedding to the text
        # embedding consumed by the image model.
        self.language_model = LanguageModel(model=cfg.llm.model)
        self.lm_head = LMHead(
            embedding_dim=cfg.llm.model_dim,
            hidden_dim=cfg.llm.embd_dim,
            num_classes=cfg.llm.nclasses,
        )

        print("LMHEAD MODEL CKPT:", lm_head_ckpt)
        self.lm_head.load_state_dict(
            torch.load(lm_head_ckpt, map_location="cpu"), strict=True
        )
        self.lm_head.eval()
        print("Loaded weights!")

    def predict(
        self,
        image: Path = Input(description="Input image."),
        prompt: str = Input(description="Input prompt."),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed", default=None
        ),
    ) -> Path:
        """Run a single prediction on the model"""
        if seed is None:
            # No seed supplied: draw a random 16-bit one and report it so the
            # run can be reproduced.
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")
        seed_everything(SEED=seed)

        # CUDA bookkeeping only makes sense (and is only safe to call) when a
        # CUDA device exists; reset_peak_memory_stats() raises otherwise.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()

        image = load_img(str(image))
        out_image = process_img(
            image, prompt, self.language_model, self.model, self.lm_head
        )

        # NOTE(review): fixed output location — every prediction overwrites
        # the previous result.
        out_path = "/tmp/out.png"
        saveImage(out_path, out_image)

        return Path(out_path)


def process_img(image, prompt, language_model, model, lm_head):
    """
    Given an image and a prompt, we run InstructIR to restore the image following the human prompt.
    image: RGB image as numpy array normalized to [0,1]
    prompt: plain python string,
    language_model: callable mapping the prompt to a text embedding,
    model: callable taking (image tensor, text embedding) and returning the restored image tensor,
    lm_head: callable mapping the text embedding to (text embedding for the model, degradation prediction),
    returns the restored image as numpy array, clipped to [0,1].
    """

    # Pure inference: disable autograd so no graph is built for the forward pass.
    with torch.no_grad():
        # Convert the HxWxC image to a float32 tensor of shape 1xCxHxW
        y = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)

        # Get the text embedding (the predicted degradation class is not used here)
        lm_embd = language_model(prompt)
        text_embd, _ = lm_head(lm_embd)

        # Forward pass: Paper Figure 2
        x_hat = model(y, text_embd)

    # convert the restored image <x_hat> back into an HxWxC np array
    restored_img = x_hat[0].permute(1, 2, 0).cpu().detach().numpy()
    return np.clip(restored_img, 0.0, 1.0)

0 comments on commit 8a0684e

Please sign in to comment.