import os
import sys
import datetime

from openai import OpenAI
import gradio as gr
from gradio.components.chatbot import ChatMessage, Message
from typing import Any, Literal

# Verbose logging is enabled by setting the DEBUG_LOG environment variable to "True".
DEBUG_LOG = os.environ.get("DEBUG_LOG") == "True"

print(f"Gradio version: {gr.__version__}")

title = None
description = None

# Incremented on every chat request; logged alongside the turn count.
chat_start_count = 0

# Model and endpoint settings are supplied via environment variables.
model_config = {
    "MODEL_NAME": os.environ.get("MODEL_NAME"),
    "MODE_DISPLAY_NAME": os.environ.get("MODE_DISPLAY_NAME"),
    "MODEL_HF_URL": os.environ.get("MODEL_HF_URL"),
    "VLLM_API_URL": os.environ.get("VLLM_API_URL"),
    "AUTH_TOKEN": os.environ.get("AUTH_TOKEN"),
}
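
# A minimal sketch of the environment this app expects; every value below is an
# illustrative placeholder, not a real model name, endpoint, or credential:
#
#   export MODEL_NAME="my-org/my-model"
#   export MODE_DISPLAY_NAME="My Model"
#   export MODEL_HF_URL="https://huggingface.co/my-org/my-model"
#   export VLLM_API_URL="http://localhost:8000/v1"
#   export AUTH_TOKEN="changeme"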

# OpenAI-compatible client pointed at the vLLM endpoint.
client = OpenAI(
    api_key=model_config.get("AUTH_TOKEN"),
    base_url=model_config.get("VLLM_API_URL"),
)


def log_message(message):
    """Print a debug message when DEBUG_LOG is enabled."""
    if DEBUG_LOG:
        print(message)


def _check_format(messages: Any, type: Literal["messages", "tuples"] = "messages") -> None:
    """Validate that `messages` matches the given Gradio chat history format."""
    if type == "messages":
        all_valid = all(
            (
                isinstance(message, dict)
                and "role" in message
                and "content" in message
            )
            or isinstance(message, ChatMessage | Message)
            for message in messages
        )
        if not all_valid:
            # Report the first offending message before raising.
            for i, message in enumerate(messages):
                if not (
                    isinstance(message, dict)
                    and "role" in message
                    and "content" in message
                ) and not isinstance(message, ChatMessage | Message):
                    print(f"_check_format() --> Invalid message at index {i}: {message}\n", file=sys.stderr)
                    break

            raise Exception(
                "Data incompatible with messages format. Each message should be a dictionary with 'role' and 'content' keys or a ChatMessage object."
            )
    elif not all(
        isinstance(message, (tuple, list)) and len(message) == 2
        for message in messages
    ):
        raise Exception(
            "Data incompatible with tuples format. Each message should be a list or tuple of length 2."
        )
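
# Illustrative calls (hypothetical data) showing what _check_format() accepts:
#
#   _check_format([{"role": "user", "content": "hi"}], "messages")  # OK
#   _check_format([["hi", "hello"]], "tuples")                      # OK
#   _check_format([{"role": "user"}], "messages")                   # raises Exception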


def chat_fn(message, history):
    """Stream the assistant reply, separating the model's reasoning from its
    final answer via [BEGIN FINAL RESPONSE] / [END FINAL RESPONSE] markers."""
    log_message(f"{'-' * 80}\nchat_fn() --> Message: {message}")

    global chat_start_count
    chat_start_count += 1
    print(
        f"{datetime.datetime.now()}: chat_start_count: {chat_start_count}, "
        f"turns: {len(history or []) // 3}"
    )

    log_message(f"Original History: {history}")
    _check_format(history, "messages")

    # Drop earlier "thought" messages (assistant messages carrying a metadata
    # title) so they are not sent back to the model.
    history = [
        item for item in history
        if not (
            isinstance(item, dict)
            and item.get("role") == "assistant"
            and isinstance(item.get("metadata"), dict)
            and item.get("metadata", {}).get("title") is not None
        )
    ]
    log_message(f"Updated History: {history}")
    _check_format(history, "messages")

    history.append({"role": "user", "content": message})
    log_message(f"History with user message: {history}")
    _check_format(history, "messages")

    stream = client.chat.completions.create(
        model=model_config.get("MODEL_NAME"),
        messages=history,
        temperature=0.8,
        stream=True,
    )

    # Placeholder thought bubble, updated in place as tokens arrive.
    history.append(gr.ChatMessage(
        role="assistant",
        content="Thinking...",
        metadata={"title": "🧠 Thought"},
    ))
    log_message(f"History added thinking: {history}")
    _check_format(history, "messages")

    output = ""
    completion_started = False
    for chunk in stream:
        # delta.content can be None (e.g. on the final chunk), so fall back to "".
        content = getattr(chunk.choices[0].delta, "content", "") or ""
        output += content

        parts = output.split("[BEGIN FINAL RESPONSE]")
        if len(parts) > 1:
            # Strip the end marker, with or without a trailing <|end|> token.
            if parts[1].endswith("[END FINAL RESPONSE]"):
                parts[1] = parts[1].replace("[END FINAL RESPONSE]", "")
            if parts[1].endswith("[END FINAL RESPONSE]\n<|end|>"):
                parts[1] = parts[1].replace("[END FINAL RESPONSE]\n<|end|>", "")

        # Text before the marker updates the thought bubble; once the marker
        # appears, text after it streams into a separate assistant message.
        history[-1 if not completion_started else -2] = gr.ChatMessage(
            role="assistant",
            content=parts[0],
            metadata={"title": "🧠 Thought"},
        )
        if completion_started:
            history[-1] = gr.ChatMessage(role="assistant", content=parts[1])
        elif len(parts) > 1:
            completion_started = True
            history.append(gr.ChatMessage(role="assistant", content=parts[1]))

        yield history[-1:] if not completion_started else history[-2:]

    log_message(f"Final History: {history}")
    _check_format(history, "messages")
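
# For reference, a hypothetical raw completion in the shape chat_fn() parses
# (marker placement as above; the text itself is invented):
#
#   "Let me weigh both options step by step... [BEGIN FINAL RESPONSE]
#   Option A is the better fit.[END FINAL RESPONSE]"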


print(f"Running model {model_config.get('MODE_DISPLAY_NAME')} ({model_config.get('MODEL_NAME')})")

gr.ChatInterface(
    chat_fn,
    title=title,
    description=description,
    theme=gr.themes.Default(primary_hue="green"),
    type="messages",
).launch()
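
# launch() is called with Gradio's defaults here; common overrides look like
# this (illustrative values, not taken from this app's configuration):
#
#   .launch(server_name="0.0.0.0", server_port=7860, share=False)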