Skip to content

Commit

Permalink
Add logging for conversations.
Browse files Browse the repository at this point in the history
  • Loading branch information
Donny-Hikari committed Oct 1, 2021
1 parent 78908d7 commit 235bbf5
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 4 deletions.
36 changes: 32 additions & 4 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import os
import numpy as np
import tensorflow as tf
import datetime
import yaml

import model, sample, encoder

Expand Down Expand Up @@ -138,19 +140,40 @@ def news_topic_chatty(enc, sess, context, output, params):
:process_response=None : Function for post-processing of the model responses.
"""
nsamples, batch_size = params['nsamples'], params['batch_size']
verbose = True
verbose = False
user_tag = "Query"
bot_tag = "Response"
chatty_params = settings['chatty']['params']
rolling_prompt = chatty_params['rolling_prompt']
log_conversation = chatty_params['log_conversation']
chat_log_dir = "test-results"

def get_example_conversation():
return f"{user_tag}: Hello. I am {user_tag}. What's your name?\n{bot_tag}: Hello. My name is {bot_tag}."

fixed_prompt = "The news today:\n" + get_news_feed() + "\nNow let's talk about it.\n\n" + get_example_conversation() + "\n\n"
rolling_prompt = 2
chat_log_fn = os.path.join(chat_log_dir, f"{datetime.datetime.now():%Y-%m-%d %H.%M.%S.%f}")

print("="*40 + " Fixed Prompt " + "="*40)
print(fixed_prompt)
print("="*80)

if log_conversation:
os.makedirs(chat_log_dir, exist_ok=True)
assert not os.path.exists(chat_log_fn), f"A log file named {chat_log_fn} already exists."
chat_log_file = open(chat_log_fn, 'w')
chat_log_file.write("# Settings for this run\n")
yaml.dump(settings, chat_log_file)
chat_log_file.write("\n# The conversation\n")
chat_log_file.write(fixed_prompt)

past_memory = []
while True:
raw_prompt = read_chat_prompt()
try:
raw_prompt = read_chat_prompt()
except (EOFError, KeyboardInterrupt) as e:
print(e)
break
raw_text = raw_prompt
if rolling_prompt:
raw_text = '\n\n'.join(past_memory + [raw_text])
Expand Down Expand Up @@ -181,10 +204,15 @@ def get_example_conversation():
preserved_response = text
# print("=" * 80)

if log_conversation:
chat_log_file.write(raw_prompt + preserved_response + "\n\n")

past_memory.append(raw_prompt + preserved_response)
if len(past_memory) > rolling_prompt:
past_memory.pop(0)

chat_log_file.close()

def interact_model(
model_name='124M',
seed=None,
Expand Down Expand Up @@ -290,4 +318,4 @@ def interact_model(
past_memory.pop()

if __name__ == "__main__":
run_model(model_name="124M", callback=news_topic_chatty)
run_model(model_name=settings['model']['name'], callback=news_topic_chatty)
8 changes: 8 additions & 0 deletions settings.yml.template
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
# setting file for chatty

model:
name: "124M"

chatty:
params:
rolling_prompt: 2
log_conversation: True

news-feed-agent:
params:
user-agent: "linux:examplenewsagent:v0.0.1 (by /u/unknown)"

0 comments on commit 235bbf5

Please sign in to comment.