Skip to content

Commit

Permalink
update 3.9.3
Browse files Browse the repository at this point in the history
  • Loading branch information
lyhue1991 committed Jul 25, 2023
1 parent 357f8ba commit 4f3bdf5
Show file tree
Hide file tree
Showing 9 changed files with 311 additions and 399 deletions.
287 changes: 148 additions & 139 deletions 01,kerasmodel_example.ipynb

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion README_en.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,6 @@ KerasModel. Here are some examples.
|**NLP**|||
|TextClassification——BERT 🔥🔥| transformers | [BERT](./examples/BERT——transformers.ipynb) |
|TokenClassification——BERT | transformers | [BERT_NER](./examples/BERT_NER——transformers.ipynb) |
|FinetuneLLM——ChatGLM2 🔥🔥🔥| transformers | [ChatGLM2](./examples/ChatGLM2——transformers.ipynb) |
|FinetuneLLM——ChatGLM2_LoRA 🔥🔥🔥| transformers,peft | [ChatGLM2_LoRA](./examples/ChatGLM2_LoRA——transformers.ipynb) |
|FinetuneLLM——ChatGLM2_AdaLoRA 🔥🔥🔥| transformers,peft | [ChatGLM2_AdaLoRA](./examples/ChatGLM2_AdaLoRA——transformers.ipynb) |
|FinetuneLLM——ChatGLM2_QLoRA 🔥🔥🔥| transformers | [ChatGLM2_QLoRA_Kaggle](./examples/ChatGLM2_QLoRA_Kaggle——transformers.ipynb) |
Expand Down
9 changes: 3 additions & 6 deletions kerasmodel_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,22 +97,19 @@ def reset(self):
#model.load_ckpt(ckpt_path) #load trained ckpt and continue training
dfhistory=model.fit(train_data=dl_train,
val_data=dl_val,
epochs=10,
patience=5,
epochs=100,
patience=10,
monitor="val_acc",
mode="max",
ckpt_path=ckpt_path,
plot=True,
wandb=False,
mixed_precision='fp16',
cpu=True
cpu=True
)

model.evaluate(dl_val,quiet=False)




# ### 5, use the model

net = model.net
Expand Down
302 changes: 93 additions & 209 deletions push2github.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def get_version():
description="pytorch❤️keras",
long_description=README,
install_requires=[
'accelerate',
'accelerate>=0.20.3',
'tqdm',
#'torch',
#'torchmetrics',
Expand Down
2 changes: 1 addition & 1 deletion torchkeras/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__="3.9.2"
__version__="3.9.3"

import sys
from .kerasmodel import KerasModel
Expand Down
3 changes: 2 additions & 1 deletion torchkeras/kerascallbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,6 @@ def on_validation_epoch_end(self, model:"KerasModel"):
dfhistory = pd.DataFrame(model.history)
self.progress.update(dfhistory['epoch'].iloc[-1])


def on_fit_end(self, model:"KerasModel"):
dfhistory = pd.DataFrame(model.history)
if dfhistory['epoch'].max()<model.epochs:
Expand All @@ -160,6 +159,8 @@ def __init__(self,figsize = (6,4),
self.figsize = (6,4)
self.save_path = save_path
self.in_jupyter = is_jupyter()
if not self.in_jupyter:
print('\nView dynamic loss/metric plot: \n'+os.path.abspath(self.save_path))

def on_fit_start(self,model: 'KerasModel'):
self.metric = model.monitor.replace('val_','')
Expand Down
24 changes: 9 additions & 15 deletions torchkeras/kerasmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,11 +156,8 @@ def fit(self, train_data, val_data=None, epochs=10, ckpt_path='checkpoint',
callbacks = callbacks if callbacks is not None else []

if bool(plot):
if is_jupyter():
from torchkeras.kerascallbacks import VisProgress
callbacks = [VisProgress()]+callbacks
from torchkeras.kerascallbacks import VisMetric
callbacks = [VisMetric()]+callbacks
from torchkeras.kerascallbacks import VisProgress,VisMetric
callbacks = [VisMetric(),VisProgress()]+callbacks

if wandb!=False:
from torchkeras.kerascallbacks import WandbCallback
Expand All @@ -174,16 +171,12 @@ def fit(self, train_data, val_data=None, epochs=10, ckpt_path='checkpoint',

start_epoch = 1 if self.from_scratch else 0

if quiet is None:
if is_jupyter():
quiet = True
else:
quiet = False
if is_jupyter() or bool(plot) or quiet is None:
quiet = True

quiet_fn = (lambda epoch:quiet) if isinstance(quiet,bool) else (
(lambda epoch:epoch>quiet) if isinstance(quiet,int) else quiet)


for epoch in range(start_epoch,epochs+1):
should_quiet = quiet_fn(epoch)

Expand Down Expand Up @@ -246,15 +239,16 @@ def fit(self, train_data, val_data=None, epochs=10, ckpt_path='checkpoint',
monitor,arr_scores[best_score_idx])))

if len(arr_scores)-best_score_idx>patience:
self.accelerator.print(colorful(
"<<<<<< {} without improvement in {} epoch,""early stopping >>>>>>"
).format(monitor,patience))
break;
break

if self.accelerator.is_local_main_process:
dfhistory = pd.DataFrame(self.history)
[cb.on_fit_end(model = self) for cb in self.callbacks
if hasattr(cb,'on_fit_end')]
if epoch<epochs:
self.accelerator.print(colorful(
"<<<<<< {} without improvement in {} epoch,""early stopping >>>>>> \n"
).format(monitor,patience))
self.net = self.accelerator.unwrap_model(self.net)
self.net.cpu()
self.load_ckpt(ckpt_path)
Expand Down
80 changes: 54 additions & 26 deletions torchkeras/pbar.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
'''
reference from https://github.com/fastai/fastprogress/
'''
import time
import time,sys
from IPython.display import clear_output, display, HTML
from torchkeras.utils import is_jupyter

from tqdm.utils import _term_move_up
move_up = _term_move_up()

def format_time(t):
"Format `t` (in seconds) to (h):mm:ss"
Expand Down Expand Up @@ -50,18 +53,27 @@ def html_progress_bar(value, total, label='', postfix='', interrupted=False):
</div>
"""

def text_progress_bar(value, total, label='', postfix='', interrupted=False):
bar_style = "🟥" if interrupted else "⬜️" #"🟥","⬜️" "○", "*"
percentage = round(value / total * 50)
finished = "🟩" * (percentage) #"🟩","●"
unfinished = bar_style * (50 - percentage)
bar = "\r{}{} {}".format(finished,unfinished,label)+" "*20+"\t"*50+f"{postfix}"+" "*20+"\t"*50
return bar

class ProgressBar:
update_every,first_its,lt = 0.2,5,'<'
def __init__(self, gen, total=None,
display=True, comment=''):
def __init__(self, gen, total=None, comment=''):
self.gen,self.comment = gen,comment
self.postfix = ''
self.total = None if total=='noinfer' else len(gen) if total is None else total
self.last_v = 0
self.display = display
self.last_v = None
self.display = True
self.in_jupyter = is_jupyter()
self.update(0)


def update(self, val):
if self.last_v is None:
self.on_iter_begin()
Expand All @@ -76,30 +88,44 @@ def update(self, val):
self.wait_for = max(int(self.update_every / (avg_t+1e-8)),1)
self.pred_t = None if self.total is None else avg_t * self.total
self.last_v,self.last_t = val,cur_t
self.update_bar(val)
if self.total is not None and val >= self.total:
self.on_iter_end()
self.last_v = None
self.last_v = self.total
else:
self.update_bar(val)

def on_iter_begin(self):
self.html_code = '\n'.join([html_progress_bar(0, self.total, ""), ""])
display(HTML(html_progress_bar_styles))
self.out = display(HTML(self.html_code), display_id=True)
if self.in_jupyter:
self.html_code = '\n'.join([html_progress_bar(0, self.total, ""), ""])
display(HTML(html_progress_bar_styles))
self.out = display(HTML(self.html_code), display_id=True)
else:
print('\n')

def on_iter_end(self):
total_time = format_time(time.time() - self.start_t)
self.comment = f'[{total_time}]'
if hasattr(self, 'out'):
self.on_update(self.total,self.comment,self.postfix)
self.comment = f'100% [{self.total}/{self.total}] [{total_time}]'
self.on_update(self.total,self.comment,self.postfix,False,1)
self.display = False
if not self.in_jupyter:
print('\n')

def on_update(self, val, comment='', postfix='', interrupted=False):
self.progress = html_progress_bar(val, self.total,comment,postfix,interrupted)
if self.display:
self.out.update(HTML(self.progress))

def on_update(self, val, comment='', postfix='', interrupted=False, up=1):
if not self.display:
return
if self.in_jupyter:
self.progress = html_progress_bar(val, self.total,comment,postfix,interrupted)
self.out.update(HTML(self.progress))
else:
progress = text_progress_bar(val, self.total, comment, postfix, interrupted)
print(move_up*up+progress,end='')

def on_interrupt(self,msg='interrupted'):
self.on_update(self.last_v,self.comment+f'[{msg}]',interrupted=True)

self.on_update(self.last_v,self.comment+f' [{msg}]',self.postfix,interrupted=True,up=1)
self.display = False
if not self.in_jupyter:
print('\n')

def __iter__(self):
if self.total != 0: self.update(0)
try:
Expand All @@ -121,15 +147,17 @@ def update_bar(self, val):
if val ==0:
self.comment = f'0% [0/{self.total}]'
return self.on_update(0, self.comment)
pct = '' if self.total is None else f'{100 * val/self.total:.2f}% '
pct = '' if self.total is None else f'{100 * val/self.total:.2f}%'
tot = '?' if self.total is None else str(self.total)
elapsed_t = self.last_t - self.start_t
remaining_t = '?' if self.pred_t is None else format_time(self.pred_t - elapsed_t)
elapsed_t = format_time(elapsed_t)
self.comment = f'{pct}[{val}/{tot} {elapsed_t}{self.lt}{remaining_t}]'
self.comment = f'{pct} [{val}/{tot}] [{elapsed_t}{self.lt}{remaining_t}]'
self.on_update(val, self.comment, self.postfix)

def set_postfix(self,**kwargs):
if not self.display:
return
postfix = ''
if 'i' in kwargs and 'n' in kwargs:
from tqdm.std import Bar
Expand All @@ -139,14 +167,14 @@ def set_postfix(self,**kwargs):
ratio = i/n
postfix+=format(Bar(ratio,default_len=20))
postfix+=f'{100 * i/n:.2f}%'
postfix+=f' [{i}/{n}] '
postfix+=f' [{i}/{n}]'
if kwargs:
postfix+='['
postfix+=' ['
for i,(key,value) in enumerate(kwargs.items()):
if isinstance(value,float):
postfix = postfix+f'{key}={format_number(value)},'
postfix = postfix+f'{key}={format_number(value)}, '
else:
postfix = postfix+f'{key}={value},'
postfix = postfix[:-1]+']'
postfix = postfix+f'{key}={value}, '
postfix = postfix[:-2]+']'
self.postfix = postfix
self.on_update(self.last_v,self.comment,self.postfix)

0 comments on commit 4f3bdf5

Please sign in to comment.