
Commit

Merge pull request THUDM#215 from zRzRzRzRzRzRzR/main
update streamlit demo, and gradio illustration
wenyihong authored Dec 19, 2023
2 parents c6a4625 + 47ba772 commit acb52b1
Showing 3 changed files with 39 additions and 20 deletions.
22 changes: 15 additions & 7 deletions basic_demo/web_demo.py
@@ -1,17 +1,25 @@
"""
This script is a simple web demo of the CogVLM and CogAgent models, designed for easy and quick demonstrations.
For a more sophisticated user interface, users are encouraged to refer to the 'composite_demo',
which is built with a more aesthetically pleasing Streamlit framework.
Usage:
- Use the interface to upload images and enter text prompts to interact with the models.
Requirements:
- Gradio (3.x only; 4.x is not supported) and other necessary Python dependencies must be installed.
- Proper model checkpoints should be accessible as specified in the script.
Note: This demo is ideal for a quick showcase of the CogVLM and CogAgent models. For a more comprehensive and interactive
experience, refer to the 'composite_demo'.
"""
import gradio as gr
import os, sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from PIL import Image
import base64
import json
import requests
import hashlib
import torch
import time
import re
import argparse
from sat.model.mixins import CachedAutoregressiveMixin
from sat.mpu import get_model_parallel_world_size
from sat.model import AutoModel
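For orientation, here is a minimal, hypothetical sketch of the Gradio 3.x pattern this web demo builds on; predict is a stand-in for the actual CogVLM/CogAgent inference call, and the widget layout is illustrative rather than the demo's exact UI:

import gradio as gr

def predict(image, prompt):
    # stand-in: the real demo runs the SAT CogVLM/CogAgent model here
    return f"(model response to: {prompt!r})"

with gr.Blocks() as demo:
    image_box = gr.Image(type="pil", label="Image")
    prompt_box = gr.Textbox(label="Prompt")
    output_box = gr.Textbox(label="Response")
    submit = gr.Button("Submit")
    submit.click(predict, inputs=[image_box, prompt_box], outputs=output_box)

demo.launch()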
24 changes: 16 additions & 8 deletions composite_demo/main.py
@@ -1,7 +1,11 @@
"""
This is a web demo of the chat versions of the CogAgent and CogVLM models.
Make sure you have installed the vicuna-7b-v1.5 tokenizer (https://huggingface.co/lmsys/vicuna-7b-v1.5);
a full checkpoint of the vicuna-7b-v1.5 LLM is not required.
Note that only one image can be processed per conversation, which means you cannot replace or insert another image
during the conversation.
The models_info parameter is explained as follows:
@@ -11,13 +15,17 @@
vlm_grounding: Use the CogVLM-grounding-17B model to complete the grounding task
Web Demo user operation logic is as follows:
CogVLM-Chat -> grounding? - yes -> Choose a template -> CogVLM-grounding-17B
                          - no -> CogVLM-chat-17B (without grounding)
CogAgent-Chat -> CogAgent-chat-18B (QA only, without grounding)
CogAgent-Agent -> CogAgent-chat-18B
               -> Choose a template -> grounding? - yes -> prompt + (with grounding)
                                                  - no -> prompt
CogAgent-vqa-hf is not included in this demo, but it can be used in the same way as CogAgent-chat-18B
in CogAgent-Chat mode.
"""

import streamlit as st
@@ -68,7 +76,7 @@ class Mode(str, Enum):
label_visibility='hidden',
)
grounding = False
selected_template_grounding_cogvlm = ""

if tab != Mode.CogAgent_Chat.value:
with st.sidebar:
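To make the routing described in the docstring above concrete, here is a hypothetical sketch of the mode-to-model dispatch; the model names come from the docstring, while select_model and its signature are illustrative rather than the demo's actual code:

from enum import Enum

class Mode(str, Enum):
    CogVLM_Chat = "CogVLM-Chat"
    CogAgent_Chat = "CogAgent-Chat"
    CogAgent_Agent = "CogAgent-Agent"

def select_model(mode: Mode, grounding: bool) -> str:
    if mode is Mode.CogVLM_Chat:
        # grounding requests route to the dedicated grounding checkpoint
        return "CogVLM-grounding-17B" if grounding else "CogVLM-chat-17B"
    # both CogAgent modes share one checkpoint; grounding only changes
    # the prompt template, not the model
    return "CogAgent-chat-18B"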
13 changes: 8 additions & 5 deletions openai_demo/openai_api.py
@@ -17,7 +17,6 @@
from PIL import Image
from io import BytesIO


MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/cogvlm-chat-hf')
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", 'lmsys/vicuna-7b-v1.5')
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -88,7 +87,7 @@ class ChatMessageInput(BaseModel):
name: Optional[str] = None


class ChatMessageResponse(BaseModel):
role: Literal["assistant"]
content: str = None
name: Optional[str] = None
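Purely illustrative: constructing one of these response messages (the field values are made up; .dict() is Pydantic v1 style, on v2 it would be .model_dump()):

msg = ChatMessageResponse(role="assistant", content="Hello!")
print(msg.dict())  # {'role': 'assistant', 'content': 'Hello!', 'name': None}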
@@ -140,7 +139,7 @@ async def list_models():
An endpoint to list available models. It returns a list of model cards.
This is useful for clients to query and understand what models are available for use.
"""
model_card = ModelCard(id="cogvlm-chat-17b")  # can be replaced by your model id like cogagent-chat-18b
return ModelList(data=[model_card])
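A hypothetical client call against this endpoint, assuming the server is running locally on port 8000 and mounts the route at /v1/models (both assumptions; check how the server is launched):

import requests

resp = requests.get("http://localhost:8000/v1/models")
print(resp.json())  # expected shape: {"data": [{"id": "cogvlm-chat-17b", ...}]}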


@@ -301,7 +300,6 @@ def generate_stream_cogvlm(model: PreTrainedModel, tokenizer: PreTrainedTokenize

logger.debug(f"==== request ====\n{query}")

input_by_model = model.build_conversation_input_ids(tokenizer, query=query, history=history,
images=[image_list[-1]])
inputs = {
@@ -314,7 +312,12 @@
inputs['cross_images'] = [[input_by_model['cross_images'][0].to(DEVICE).to(torch_type)]]

input_echo_len = len(inputs["input_ids"][0])
streamer = TextIteratorStreamer(
tokenizer=tokenizer,
timeout=60.0,
skip_prompt=True,
skip_special_tokens=True
)
gen_kwargs = {
"repetition_penalty": repetition_penalty,
"max_new_tokens": max_new_tokens,
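The TextIteratorStreamer configured above follows the standard transformers streaming pattern: generation runs on a background thread while the streamer yields decoded text as it arrives. A minimal sketch, assuming model, tokenizer, and inputs are prepared as in the demo:

from threading import Thread
from transformers import TextIteratorStreamer

streamer = TextIteratorStreamer(tokenizer=tokenizer, timeout=60.0,
                                skip_prompt=True, skip_special_tokens=True)
# generate() feeds the streamer token by token from a background thread
Thread(target=model.generate, kwargs=dict(**inputs, streamer=streamer,
                                          max_new_tokens=1024)).start()
for new_text in streamer:  # yields decoded chunks; raises after the 60 s timeout
    print(new_text, end="", flush=True)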
