Skip to content

Commit

Permalink
update 3.9.4
Browse files Browse the repository at this point in the history
  • Loading branch information
lyhue1991 committed Aug 19, 2023
1 parent b00c31e commit 9c6f1c7
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 113 deletions.
5 changes: 2 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,10 @@ dfhistory=model.fit(train_data=dl_train,
val_data=dl_val,
epochs=20,
patience=3,
ckpt_path='checkpoint.pt',
ckpt_path='checkpoint',
monitor="val_acc",
mode="max",
plot=True,

plot=True
)

```
Expand Down
55 changes: 21 additions & 34 deletions push2github.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
"copying torchkeras/chat/chatgpt.py -> torchkeras-3.9.4/torchkeras/chat\n",
"copying torchkeras/chat/chatllm.py -> torchkeras-3.9.4/torchkeras/chat\n",
"copying torchkeras/chat/conversations.py -> torchkeras-3.9.4/torchkeras/chat\n",
"copying torchkeras/chat/data.py -> torchkeras-3.9.4/torchkeras/chat\n",
"copying torchkeras/chat/stream_generate.py -> torchkeras-3.9.4/torchkeras/chat\n",
"copying torchkeras/models/__init__.py -> torchkeras-3.9.4/torchkeras/models\n",
"copying torchkeras/models/resnet.py -> torchkeras-3.9.4/torchkeras/models\n",
Expand Down Expand Up @@ -125,6 +126,7 @@
"copying torchkeras/chat/stream_generate.py -> build/lib/torchkeras/chat\n",
"copying torchkeras/chat/chatgpt.py -> build/lib/torchkeras/chat\n",
"copying torchkeras/chat/conversations.py -> build/lib/torchkeras/chat\n",
"copying torchkeras/chat/data.py -> build/lib/torchkeras/chat\n",
"creating build/lib/torchkeras/models\n",
"copying torchkeras/models/ssd.py -> build/lib/torchkeras/models\n",
"copying torchkeras/models/unet.py -> build/lib/torchkeras/models\n",
Expand All @@ -151,6 +153,7 @@
"copying build/lib/torchkeras/chat/stream_generate.py -> build/bdist.macosx-11.0-arm64/wheel/torchkeras/chat\n",
"copying build/lib/torchkeras/chat/chatgpt.py -> build/bdist.macosx-11.0-arm64/wheel/torchkeras/chat\n",
"copying build/lib/torchkeras/chat/conversations.py -> build/bdist.macosx-11.0-arm64/wheel/torchkeras/chat\n",
"copying build/lib/torchkeras/chat/data.py -> build/bdist.macosx-11.0-arm64/wheel/torchkeras/chat\n",
"copying build/lib/torchkeras/__init__.py -> build/bdist.macosx-11.0-arm64/wheel/torchkeras\n",
"creating build/bdist.macosx-11.0-arm64/wheel/torchkeras/models\n",
"copying build/lib/torchkeras/models/ssd.py -> build/bdist.macosx-11.0-arm64/wheel/torchkeras/models\n",
Expand Down Expand Up @@ -237,6 +240,7 @@
"adding 'torchkeras/chat/chatgpt.py'\n",
"adding 'torchkeras/chat/chatllm.py'\n",
"adding 'torchkeras/chat/conversations.py'\n",
"adding 'torchkeras/chat/data.py'\n",
"adding 'torchkeras/chat/stream_generate.py'\n",
"adding 'torchkeras/models/__init__.py'\n",
"adding 'torchkeras/models/resnet.py'\n",
Expand Down Expand Up @@ -316,7 +320,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"id": "7a76c373",
"metadata": {},
"outputs": [],
Expand All @@ -328,25 +332,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"id": "ee26e2f1-9ecc-478f-9224-e6690188de94",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"git@hf.co: Permission denied (publickey).\n"
]
}
],
"source": [
"!ssh -T git@hf.co "
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"id": "7a23b162",
"metadata": {},
"outputs": [],
Expand All @@ -356,16 +342,17 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"id": "cdb41033",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[master 4ff5bcb] update 3.9.4\n",
" 2 files changed, 16 insertions(+), 6 deletions(-)\n"
"[master b00c31e] update 3.9.4\n",
" 5 files changed, 454 insertions(+), 43 deletions(-)\n",
" create mode 100644 torchkeras/chat/data.py\n"
]
}
],
Expand Down Expand Up @@ -415,7 +402,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"id": "12e49933",
"metadata": {},
"outputs": [
Expand All @@ -426,12 +413,12 @@
"Enumerating objects: 16, done.\n",
"Counting objects: 100% (16/16), done.\n",
"Delta compression using up to 8 threads\n",
"Compressing objects: 100% (10/10), done.\n",
"Writing objects: 100% (10/10), 1.12 KiB | 1.12 MiB/s, done.\n",
"Total 10 (delta 8), reused 0 (delta 0), pack-reused 0\n",
"remote: Resolving deltas: 100% (8/8), completed with 6 local objects.\u001b[K\n",
"Compressing objects: 100% (9/9), done.\n",
"Writing objects: 100% (9/9), 3.29 KiB | 3.29 MiB/s, done.\n",
"Total 9 (delta 6), reused 0 (delta 0), pack-reused 0\n",
"remote: Resolving deltas: 100% (6/6), completed with 6 local objects.\u001b[K\n",
"To github.com:lyhue1991/torchkeras.git\n",
" 7b1a530..4ff5bcb master -> master\n"
" 4ff5bcb..b00c31e master -> master\n"
]
}
],
Expand All @@ -451,7 +438,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"id": "1e3cab48-5eff-4557-9cf0-5deb42441e91",
"metadata": {},
"outputs": [
Expand All @@ -462,12 +449,12 @@
"Enumerating objects: 16, done.\n",
"Counting objects: 100% (16/16), done.\n",
"Delta compression using up to 8 threads\n",
"Compressing objects: 100% (10/10), done.\n",
"Writing objects: 100% (10/10), 1.12 KiB | 1.12 MiB/s, done.\n",
"Total 10 (delta 8), reused 0 (delta 0), pack-reused 0\n",
"Compressing objects: 100% (9/9), done.\n",
"Writing objects: 100% (9/9), 3.29 KiB | 3.29 MiB/s, done.\n",
"Total 9 (delta 6), reused 0 (delta 0), pack-reused 0\n",
"remote: Powered by \u001b[01;33mGITEE.COM \u001b[0m[\u001b[01;35mGNK-6.4\u001b[0m]\u001b[0m\u001b[K\n",
"To https://gitee.com/Python_Ai_Road/torchkeras\n",
" 7b1a530..4ff5bcb master -> master\n"
" 4ff5bcb..b00c31e master -> master\n"
]
}
],
Expand Down
3 changes: 1 addition & 2 deletions torchkeras/chat/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from .chatgpt import ChatGPT
from .chatglm import ChatGLM
from .conversations import get_conv_template, conv_templates
from .chatllm import ChatLLM
from .data import data_collator
from .chatllm import ChatLLM
19 changes: 9 additions & 10 deletions torchkeras/chat/chatllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import sys
from copy import deepcopy
from .conversations import conv_templates, get_conv_template
from .data import build_inputs_labels
from .text2ids import build_inputs_labels

#chat tool for chatglm2-6b,baichuan-13b,internlm-chat-7b,qwen-7b-chat and more...
class ChatLLM:
Expand All @@ -19,8 +19,11 @@ def __init__(self,model,tokenizer,
if not self.tokenizer.eos_token_id:
self.tokenizer.eos_token_id = (model.config.eos_token_id
or model.generation_config.eos_token_id)

self.model_type = model_type if model_type else self.get_model_type()
conv = get_conv_template(self.model_type)
self.conv_template = conv

self.stop_words_ids = [[w] for w in conv.stop_token_ids] if conv.stop_token_ids else []
self.model.generation_config.stop_words_ids = self.stop_words_ids
self.model.generation_config.max_new_tokens = max_new_tokens
Expand Down Expand Up @@ -75,15 +78,10 @@ def build_messages(cls,query=None,history=None,system=None):
return messages

def build_conversations(self,messages):
model = self.model
if not hasattr(model,'conv_template'):
model_type = self.get_model_type()
model.conv_template =get_conv_template(model_type)
conv = deepcopy(model.conv_template)
conv = deepcopy(self.conv_template)
msgs_sys = [d for d in messages if d['role']=='system']
if msgs_sys:
conv.set_system_message(msgs_sys[0]['content'])

for d in messages:
if d['role']=='user':
conv.append_message(conv.roles[0], d['content'])
Expand All @@ -102,7 +100,7 @@ def build_prompt(self,messages):
def build_inputs_labels(self,messages,multi_rounds=True):
conv = self.build_conversations(messages)
inputs,labels = build_inputs_labels(
conv,self.tokenizer, multi_rounds=multi_rounds)
conv, self.tokenizer, multi_rounds=multi_rounds)
return inputs,labels

def chat(self, messages, stream=False, generation_config=None):
Expand Down Expand Up @@ -141,7 +139,9 @@ def chat(self, messages, stream=False, generation_config=None):
from .stream_generate import NewGenerationMixin, StreamGenerationConfig
model.__class__.generate = NewGenerationMixin.generate
model.__class__.sample_stream = NewGenerationMixin.sample_stream
stream_config = StreamGenerationConfig(**generation_config.to_dict(),do_stream=True)
config_dic = generation_config.to_dict()
config_dic.update({'do_stream':True})
stream_config = StreamGenerationConfig(**config_dic)

def stream_generator():
outputs = []
Expand Down Expand Up @@ -176,7 +176,6 @@ def __call__(self,query):
self.history.append((query,response))
return response


def register_magic(self):
import IPython
from IPython.core.magic import (Magics, magics_class, line_magic,
Expand Down
Loading

0 comments on commit 9c6f1c7

Please sign in to comment.