Skip to content

Commit

Permalink
Merge pull request THUDM#13 from tongyx361/main
Browse files Browse the repository at this point in the history
improve memory management features
  • Loading branch information
xujz18 authored May 6, 2023
2 parents 45d6354 + 80bc00c commit 9a89c48
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 2 deletions.
10 changes: 9 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,14 @@ https://user-images.githubusercontent.com/98524878/233889490-5c4a062f-bb5e-4179-
<img src="https://user-images.githubusercontent.com/98524878/233829640-12190bff-f62b-4160-b05d-29624fa83677.jpg" width="700px">
</p>

#### Other Features

##### Memory Management

- The ImageReward model is not loaded **until the first script run**.
- **"Reload UI"** neither reloads nor unloads the model; it **reuses** the currently loaded model (if one exists).
- An **"Unload Model"** button is provided to manually unload the currently loaded model.

## Reproduce Experiments in Table 2

<p align="center">
Expand All @@ -165,7 +173,7 @@ TODO

```
@misc{xu2023imagereward,
title={ImageReward: Learning and Evaluating Human Preferences for Text-to-Image Generation},
title={ImageReward: Learning and Evaluating Human Preferences for Text-to-Image Generation},
author={Jiazheng Xu and Xiao Liu and Yuchen Wu and Yuxuan Tong and Qinkai Li and Ming Ding and Jie Tang and Yuxiao Dong},
year={2023},
eprint={2304.05977},
Expand Down
90 changes: 89 additions & 1 deletion demo/sdwebui/image_reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,16 @@
import gradio as gr

from modules import sd_samplers, shared

# Both sides of the botched merge imported exactly the same four names from
# modules.processing; keep the multi-line (HEAD) form and drop the conflict markers.
from modules.processing import (
    Processed,
    process_images,
    StableDiffusionProcessing,
    create_infotext,
)
import modules.images as images
from modules.shared import opts, cmd_opts, state

Expand All @@ -12,6 +21,15 @@
from pathlib import Path
import ImageReward as reward

def unload_image_reward_model():
    """Drop the global reference to the loaded ImageReward model, if any.

    The model is stored on ``modules.shared`` (set lazily in ``Script.run``)
    so it survives "Reload UI".  Guard with ``hasattr``: clicking the
    "Unload Model" button before the first generation would otherwise raise
    ``AttributeError`` because the attribute does not exist yet.
    """
    if hasattr(shared, "image_reward_model"):
        del shared.image_reward_model
class Script(scripts.Script):
    """sd-webui script that scores generated images with ImageReward.

    The ImageReward model (~1,600 MB) is loaded lazily on the first run and
    cached on ``modules.shared`` so it is shared across modules and survives
    "Reload UI"; it can be released via the "Unload Model" button.
    """

    def title(self):
        # Name shown in the webui "Script" dropdown.
        return "ImageReward - generate human preference scores"

    def show(self, is_txt2img):
        # Always visible, in both txt2img and img2img tabs.
        return True

    def ui(self, is_txt2img):
        """Build the script's Gradio controls.

        Returns the list of input components whose values are passed
        positionally to :meth:`run` (after ``p``).
        """
        with gr.Blocks():
            with gr.Row():
                gr.Markdown(
                    value="**Tip**: It will take a little time to **load** the ImageReward model before the first generation."
                )
            with gr.Row():
                with gr.Column():
                    filter_out_low_scores = gr.Checkbox(
                        value=False, label="Filter out images with low scores"
                    )
                with gr.Column():
                    lower_score_limit = gr.Textbox(value=0, label="Lower score limit")
            with gr.Row():
                gr.Markdown(
                    value="ImageReward model takes about **1,600 MB** of memory."
                )
            with gr.Row():
                unload_button = gr.Button(value="Unload Model")
                unload_button.click(unload_image_reward_model)

        return [filter_out_low_scores, lower_score_limit]

    def run(self, p, filter_out_low_scores, lower_score_limit):
        """Generate images, score each with ImageReward, and sort by score.

        :param p: the StableDiffusionProcessing job to run.
        :param filter_out_low_scores: drop images scoring <= lower_score_limit.
        :param lower_score_limit: textbox value; coerced to float when non-empty.
        :return: a Processed object whose images are sorted by score (desc).
        """
        try:
            shared.image_reward_model  # if loaded, do nothing
        except AttributeError:
            # load the model
            # by default, it will:
            # 1. set the device to cuda if available
            # 2. download the model and cache it in `~/.cache/` if model is not found
            # you can also configure the device and cache dir by passing in the arguments
            shared.image_reward_model = reward.load(
                "ImageReward-v1.0"
            )  # using shared to make the model object global among modules

        # preprocess parameters
        if lower_score_limit != "":
            lower_score_limit = float(lower_score_limit)

        # generate images
        proc = process_images(p)

        # score each generated image; inference only, so no grad tracking
        gens = proc.images
        for img in gens:
            with torch.no_grad():
                score = shared.image_reward_model.score(p.prompt, img)
            img.info["score"] = score
            img.info["parameters"] += f"\n ImageReward Score: {score:.4f}"

        # filter out images with scores lower than the lower limit
        if filter_out_low_scores:
            imgs = list(filter(lambda x: x.info["score"] > lower_score_limit, gens))
        else:
            imgs = gens

        # append score to info
        infotexts = [img.info["parameters"] for img in imgs]

        # sort by score, best first
        img_info_list = list(zip(imgs, infotexts))
        img_info_list.sort(key=lambda x: x[0].info["score"], reverse=True)
        if img_info_list:
            imgs, infotexts = zip(*img_info_list)
        else:
            # every image was filtered out; zip(*[]) would make the tuple
            # unpacking above raise ValueError
            imgs, infotexts = [], []

        # return Processed object
        # NOTE(review): the opening of this call was collapsed in the diff view;
        # positional args reconstructed from the visible keyword args — confirm
        # against upstream demo/sdwebui/image_reward.py.
        return Processed(
            p,
            imgs,
            info=proc.info,
            seed=proc.seed,
            infotexts=infotexts,
            index_of_first_image=proc.index_of_first_image,
        )

0 comments on commit 9a89c48

Please sign in to comment.