Build error
Quyet committed · de337bd · 1 parent: 617fa8c
update chat state history, add initial greeting
README.md CHANGED
@@ -1,10 +1,10 @@
 ---
 title: PsyPlus
 emoji: 🤖
-colorFrom:
-colorTo:
+colorFrom: pink
+colorTo: blue
 sdk: gradio
-sdk_version: 3.
+sdk_version: 3.10.0
 app_file: app.py
 pinned: false
 license: gpl-3.0
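These fields are the standard Spaces configuration read from the README front matter: the commit fills in the previously blank theme colors and pins the SDK to Gradio 3.10.0 instead of leaving the version unspecified.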
app.py CHANGED
@@ -4,6 +4,7 @@ import matplotlib.pyplot as plt
 from threading import Timer
 import gradio as gr
 
+import torch
 from transformers import (
     GPT2LMHeadModel, GPT2Tokenizer,
     AutoModelForSequenceClassification, AutoTokenizer,
@@ -11,6 +12,8 @@ from transformers import (
 )
 # reference: https://huggingface.co/spaces/bentrevett/emotion-prediction
 # and https://huggingface.co/spaces/tareknaous/Empathetic-DialoGPT
+# gradio vs streamlit https://trojrobert.github.io/a-guide-for-deploying-and-serving-machine-learning-with-model-streamlit-vs-gradio/
+# https://gradio.app/interface_state/
 
 def euc_100():
     # 1,2,3. asks about the user's emotions and store data
@@ -77,16 +80,14 @@ def euc_100():
 
 
 def load_neural_emotion_detector():
-    model_name =
+    model_name = 'joeddav/distilbert-base-uncased-go-emotions-student'
     model = AutoModelForSequenceClassification.from_pretrained(model_name)
+    model.eval()
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     pipe = pipeline('text-classification', model=model, tokenizer=tokenizer,
                     return_all_scores=True, truncation=True)
     return pipe
 
-def sort_predictions(predictions):
-    return sorted(predictions, key=lambda x: x['score'], reverse=True)
-
 def plot_emotion_distribution(predictions):
     fig, ax = plt.subplots()
     ax.bar(x=[i for i, _ in enumerate(prediction)],
@@ -98,9 +99,9 @@ def plot_emotion_distribution(predictions):
 
 def rulebase(text):
     keywords = {
-        'life_safety': [
-        'immediacy': [
-        'manifestation': [
+        'life_safety': ['death', 'suicide', 'murder', 'to perish together', 'jump off the building'],
+        'immediacy': ['now', 'immediately', 'tomorrow', 'today'],
+        'manifestation': ['never stop', 'every moment', 'strong', 'very']
     }
 
     # if found dangerous kw/topics
@@ -127,7 +128,7 @@ def euc_200(text, testing=True):
     if not testing:
         pipe = load_neural_emotion_detector()
         prediction = pipe(text)[0]
-        prediction =
+        prediction = sorted(predictions, key=lambda x: x['score'], reverse=True)
         plot_emotion_distribution(prediction)
 
     # get the most probable emotion. TODO: modify this part, may take sum of prob. over all negative emotion
@@ -174,46 +175,58 @@ def euc_200(text, testing=True):
     pass
 
 
-tokenizer
-model = GPT2LMHeadModel.from_pretrained("tareknaous/dialogpt-empathetic-dialogues")
-model.eval()
-
-def chat(message, history):
-    history = history or []
+def _chat(message, history, model, tokenizer, args):
     eos = tokenizer.eos_token
-
+    history = history or {
+        'text': args.greeting,
+        'input_ids': tokenizer.encode(args.greeting[-1][1] + eos, return_tensors='pt'),
+    }
+    # TODO: only take the latest X turns, otherwise the text becomes longer and takes more time to process
 
-
-
+    message_ids = tokenizer.encode(message + eos, return_tensors='pt')
+    history['input_ids'] = torch.cat([history['input_ids'], message_ids], dim=-1)
+
+    bot_output_ids = model.generate(history['input_ids'],
                                     max_length=1000,
                                     do_sample=True, top_p=0.9, temperature=0.8,
                                     pad_token_id=tokenizer.eos_token_id)
-    response = tokenizer.decode(bot_output_ids[:,
+    response = tokenizer.decode(bot_output_ids[:, history['input_ids'].shape[-1]:][0],
                                 skip_special_tokens=True)
+    if args.run_on_own_server == 1:
+        print((message, response), bot_output_ids[0][-10:])
 
-    history
-
+    history['input_ids'] = bot_output_ids
+    history['text'].append((message, response))
+    return history['text'], history
 
 
 if __name__ == '__main__':
     # euc_100()
     # euc_200('I am happy about my academic record.')
     parser = argparse.ArgumentParser()
-    parser.add_argument('--
+    parser.add_argument('--run_on_own_server', type=int, default=0, help='if test on own server, need to use share mode')
     args = parser.parse_args()
+    args.greeting = [('','Hi you!')]
+
+    tokenizer = GPT2Tokenizer.from_pretrained('tareknaous/dialogpt-empathetic-dialogues')
+    model = GPT2LMHeadModel.from_pretrained('tareknaous/dialogpt-empathetic-dialogues')
+    model.eval()
+    def chat(message, history):
+        return _chat(message, history, model, tokenizer, args)
 
-    title =
-    description =
+    title = 'PsyPlus Empathetic Chatbot'
+    description = 'Gradio demo for product of PsyPlus. Based on rule-based CBT and conversational AI model DialoGPT'
+    chatbot = gr.Chatbot(value=args.greeting)
     iface = gr.Interface(
         chat,
-        [
-        [
-
-        allow_flagging=
+        ['text', 'state'],
+        [chatbot, 'state'],
+        # css=".gradio-container {background-color: white}",
+        allow_flagging='never',
        title=title,
         description=description,
     )
-    if args.
+    if args.run_on_own_server == 0:
         iface.launch(debug=True)
     else:
-        iface.launch(debug=True, server_name=
+        iface.launch(debug=True, server_name='0.0.0.0', server_port=2022, share=True)
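The heart of the change is the new `_chat` helper: chat state is now a dict carrying both the rendered transcript (`'text'`) and the running token tensor (`'input_ids'`), seeded by encoding the greeting so the model treats its own opening line as conversation context. Below is a minimal, self-contained sketch of the same encode/generate/decode cycle for a single turn, using the checkpoint named in the diff; the example message and variable names are illustrative, not part of the Space:

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('tareknaous/dialogpt-empathetic-dialogues')
model = GPT2LMHeadModel.from_pretrained('tareknaous/dialogpt-empathetic-dialogues')
model.eval()

# Each turn ends with the EOS token so the model can tell the speakers apart.
input_ids = tokenizer.encode('I feel anxious today.' + tokenizer.eos_token,
                             return_tensors='pt')
with torch.no_grad():
    output_ids = model.generate(input_ids,
                                max_length=1000,
                                do_sample=True, top_p=0.9, temperature=0.8,
                                pad_token_id=tokenizer.eos_token_id)

# generate() returns prompt + reply, so slice off the prompt length to decode
# only the new tokens, the same slicing _chat applies to history['input_ids'].
reply = tokenizer.decode(output_ids[:, input_ids.shape[-1]:][0],
                         skip_special_tokens=True)
print(reply)

Carrying `bot_output_ids` forward as the next turn's `history['input_ids']` is what gives the bot memory across turns; the TODO in the diff notes the cost, since this tensor grows with every exchange until it is truncated.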
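The `['text', 'state']` inputs and `[chatbot, 'state']` outputs follow Gradio's session-state pattern (the https://gradio.app/interface_state/ link added in the comments): whatever the function returns in the 'state' slot is passed back as the 'state' argument on that browser session's next call, which is how each visitor gets an independent history. A minimal sketch of the pattern under Gradio 3.x, with an illustrative `echo_chat` function that is not part of the Space:

import gradio as gr

def echo_chat(message, history):
    # Gradio passes None for 'state' on a session's first call.
    history = history or []
    history.append((message, f'You said: {message}'))
    # First output renders in the Chatbot; second is stored as session state.
    return history, history

iface = gr.Interface(
    echo_chat,
    ['text', 'state'],       # 'state' input receives the previous 'state' output
    [gr.Chatbot(), 'state'],
    allow_flagging='never',
)
iface.launch()

In the commit itself the opening message is shown by constructing the component as `gr.Chatbot(value=args.greeting)`, while `_chat` seeds the token history from the same greeting, keeping what the user sees and what the model has "heard" in sync.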