Space status: Runtime error

Commit b347ca3
Parent(s): 7e1bc94

Add all BioGPT app files with LFS tracking

Files changed:
- Dockerfile +36 -0
- app.py +1067 -0
- bpecodes +0 -0
- dict.txt +0 -0
- fast +3 -0
- hoc_best.pt +3 -0
- requirements.txt +7 -0
- templates/index.html +191 -0
Dockerfile
ADDED
@@ -0,0 +1,36 @@
+# Use an official Python runtime as a parent image (e.g., 3.11 or 3.12)
+FROM python:3.11-slim
+
+# Set the working directory in the container
+WORKDIR /app
+
+# Install system dependencies if needed (e.g., build-essential if compiling fastBPE)
+# RUN apt-get update && apt-get install -y --no-install-recommends build-essential && rm -rf /var/lib/apt/lists/*
+
+# Copy the requirements file into the container first (for Docker caching)
+COPY requirements.txt requirements.txt
+
+# Install Python dependencies
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Copy the application code and necessary files into the container
+# IMPORTANT: This assumes 'fast' is pre-compiled for Linux x86_64
+COPY app.py .
+COPY hoc_best.pt .
+COPY bpecodes .
+COPY dict.txt .
+COPY fast .
+COPY templates/ ./templates/
+
+# Ensure the 'fast' binary is executable within the container
+RUN chmod +x ./fast
+
+# Make port 5000 available (as defined in README.md and app.py)
+EXPOSE 5000
+
+# Define environment variable for Python output buffering
+ENV PYTHONUNBUFFERED=1
+
+# Run the app using waitress (a production-ready WSGI server)
+# Assumes your Flask app instance is named 'app' in 'app.py'
+CMD ["waitress-serve", "--host=0.0.0.0", "--port=5000", "app:app"]
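Note: the CMD above uses waitress-serve's module:callable form, so the container never relies on Flask's development server. Below is a minimal sketch of the equivalent programmatic entry point, assuming the Flask instance exported by app.py is named app (as the Dockerfile comment states); the file name serve.py is hypothetical and not part of this commit.

# serve.py -- hypothetical alternative entry point, equivalent to the Dockerfile CMD above
from waitress import serve   # waitress is already listed in requirements.txt

from app import app          # assumes app.py defines a Flask instance named `app`

if __name__ == "__main__":
    # Bind to the same interface and port that the Dockerfile EXPOSEs
    serve(app, host="0.0.0.0", port=5000)

Either entry point serves the same WSGI application; the CMD form simply avoids shipping an extra file.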
app.py
ADDED
@@ -0,0 +1,1067 @@
+import os
+import subprocess
+import math
+import difflib
+import tempfile
+import numpy as np
+import pandas as pd
+from tqdm.auto import tqdm
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from dataclasses import dataclass
+from sacremoses import MosesDetokenizer
+from flask import Flask, request, jsonify, render_template
+import traceback # Import traceback for better error logging
+import re # Import the regular expression module
+
+# --- Constants and Paths ---
+# Ensure these files are in the same directory as app.py or provide correct paths
+FINETUNED_MODEL_PATH = "hoc_best.pt"
+BPE_CODES_PATH = "bpecodes"
+DICT_TXT_PATH = "dict.txt"
+FASTBPE_BIN_PATH = "./fast" # Assumes fast executable is alongside app.py
+
+HALLMARKS = [ # Keep this consistent with training/evaluation
+    "activating invasion and metastasis", "avoiding immune destruction",
+    "cellular energetics", "enabling replicative immortality",
+    "evading growth suppressors", "genomic instability and mutation",
+    "inducing angiogenesis", "resisting cell death",
+    "sustaining proliferative signaling", "tumor promoting inflammation",
+]
+
+# --- Model Architecture Definitions (Copy from your notebook) ---
+# NOTE: Make sure these classes are IDENTICAL to the ones used for training
+# including GPTConfig, LayerNorm, CausalSelfAttention, MLP, Block, GPT, GPTWithSoftPrompt
+
+@dataclass
+class GPTConfig:
+    block_size: int
+    vocab_size: int
+    n_layer: int
+    n_head: int
+    n_embd: int
+    dropout: float = 0.0
+    bias: bool = True
+
+class LayerNorm(nn.Module):
+    # (Copied from notebook)
+    def __init__(self, ndim, bias):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(ndim))
+        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
+    def forward(self, x):
+        return F.layer_norm(x, self.weight.shape, self.weight, self.bias, 1e-5)
+
+class CausalSelfAttention(nn.Module):
+    # (Copied from notebook)
+    def __init__(self, config):
+        super().__init__()
+        assert config.n_embd % config.n_head == 0
+        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
+        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
+        self.attn_dropout = nn.Dropout(config.dropout)
+        self.resid_dropout = nn.Dropout(config.dropout)
+        self.n_head = config.n_head
+        self.n_embd = config.n_embd
+        self.flash = hasattr(F, 'scaled_dot_product_attention') # Check for flash attention
+        if not self.flash:
+            # print("Warning: Flash Attention not available.") # Optional warning
+            # Make the buffer persistent otherwise device mismatches during forward pass
+            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
+                                 .view(1, 1, config.block_size, config.block_size), persistent=True)
+        #else:
+        # print("Using Flash Attention.") # Optional info
+
+    def forward(self, x):
+        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
+        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
+        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+
+        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
+        if self.flash:
+            # efficient attention using Flash Attention CUDA kernels
+            y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.attn_dropout.p if self.training else 0.0, is_causal=True)
+        else:
+            # manual implementation of attention
+            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+            # Ensure bias buffer is used correctly
+            # Check if bias buffer exists before using it
+            if hasattr(self, 'bias'):
+                att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
+            else:
+                # Fallback if somehow bias wasn't registered (shouldn't happen with persistent=True)
+                mask = torch.tril(torch.ones(T, T, device=x.device)).view(1, 1, T, T)
+                att = att.masked_fill(mask == 0, float('-inf'))
+
+            att = F.softmax(att, dim=-1)
+            att = self.attn_dropout(att)
+            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+
+        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
+        # output projection
+        y = self.resid_dropout(self.c_proj(y))
+        return y
+
+
+class MLP(nn.Module):
+    # (Copied from notebook)
+    def __init__(self, config):
+        super().__init__()
+        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
+        self.gelu = nn.GELU()
+        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
+        self.dropout = nn.Dropout(config.dropout)
+    def forward(self, x):
+        x = self.c_fc(x)
+        x = self.gelu(x)
+        x = self.c_proj(x)
+        x = self.dropout(x)
+        return x
+
+class Block(nn.Module):
+    # (Copied from notebook)
+    def __init__(self, config):
+        super().__init__()
+        self.ln1 = LayerNorm(config.n_embd, bias=config.bias)
+        self.attn = CausalSelfAttention(config)
+        self.ln2 = LayerNorm(config.n_embd, bias=config.bias)
+        self.mlp = MLP(config)
+    def forward(self, x):
+        x = x + self.attn(self.ln1(x))
+        x = x + self.mlp(self.ln2(x))
+        return x
+
+class GPT(nn.Module):
+    # (Copied from notebook - simplified _init_weights and removed generate)
+    def __init__(self, config):
+        super().__init__()
+        #assert config.vocab_size is not None
+        #assert config.block_size is not None
+        self.config = config
+
+        self.transformer = nn.ModuleDict(dict(
+            wte = nn.Embedding(config.vocab_size, config.n_embd),
+            wpe = nn.Embedding(config.block_size, config.n_embd),
+            drop = nn.Dropout(config.dropout),
+            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
+            ln_f = LayerNorm(config.n_embd, bias=config.bias),
+        ))
+        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+        self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
+
+        # init all weights
+        self.apply(self._init_weights)
+        # apply special scaled init to the residual projections, per GPT-2 paper
+        for pn, p in self.named_parameters():
+            if pn.endswith('c_proj.weight'):
+                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
+
+    def _init_weights(self, module):
+        if isinstance(module, nn.Linear):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+            if module.bias is not None:
+                torch.nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+    def forward(self, idx, targets=None):
+        device = idx.device
+        b, t = idx.size()
+        if t > self.config.block_size:
+            # Crop sequence if longer than block size
+            print(f"Warning: Input sequence length ({t}) > block size ({self.config.block_size}). Cropping.")
+            idx = idx[:, -self.config.block_size:]
+            t = self.config.block_size
+        #assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
+        pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
+
+        # forward the GPT model itself
+        tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
+        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
+        x = self.transformer.drop(tok_emb + pos_emb)
+        for block in self.transformer.h:
+            x = block(x)
+        x = self.transformer.ln_f(x)
+
+        if targets is not None:
+            # if we are given some desired targets also calculate the loss
+            logits = self.lm_head(x)
+            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
+            return logits, loss
+        else:
+            # inference-time mini-optimization: only forward the lm_head on the very last position
+            logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
+            # Check for NaN/Inf in logits before returning
+            if torch.isnan(logits).any() or torch.isinf(logits).any():
+                print("WARNING: NaN or Inf detected in logits during inference.")
+                # Handle appropriately - maybe return an error indicator or zero logits?
+                # For now, just print warning.
+            return logits, None
+
+class GPTWithSoftPrompt(nn.Module):
+    # (Copied from notebook - simplified)
+    def __init__(self, base_gpt: GPT, prompt_len=1):
+        super().__init__()
+        self.config = base_gpt.config
+        self.transformer = base_gpt.transformer
+        self.lm_head = base_gpt.lm_head
+        C = self.config.n_embd
+        self.soft_prompt = nn.Parameter(torch.zeros(1, prompt_len, C)) # Keep on CPU first
+        nn.init.normal_(self.soft_prompt, mean=0.0, std=0.02)
+
+    def forward(self, idx, targets=None):
+        B, T = idx.shape
+        device = idx.device # Get device from input tensor
+
+        # Make sure soft_prompt is on the same device as input
+        soft_prompt_on_device = self.soft_prompt.to(device)
+
+        # token + pos
+        tok_emb = self.transformer.wte(idx) # (B,T,C)
+        pos = torch.arange(0, T, dtype=torch.long, device=device)
+        pos_emb = self.transformer.wpe(pos) # (T,C)
+        x_tokens = tok_emb + pos_emb
+
+        # prepend soft prompt
+        soft = soft_prompt_on_device.expand(B, -1, -1) # (B,P,C)
+
+        # --- FIX: Define P before the if/else block ---
+        P = soft.size(1) # Get soft prompt length
+
+        x = torch.cat([soft, x_tokens], dim=1) # (B,P+T,C)
+
+        # --- Standard Transformer forward pass ---
+        x = self.transformer.drop(x)
+        for block in self.transformer.h:
+            x = block(x)
+        x = self.transformer.ln_f(x)
+        logits = self.lm_head(x) # (B,P+T,V)
+        # --- End Standard ---
+
+        if targets is None:
+            # Inference: return logits for the last token of the *original* sequence
+            # We need the prediction *after* the last input token, which is at index T (P+T-1 overall)
+            # Use P which is now defined
+            # Ensure index is within bounds
+            target_logit_index = P + T - 1
+            if target_logit_index >= logits.size(1):
+                print(f"Warning: Calculated logit index {target_logit_index} out of bounds for logits shape {logits.shape}. Returning last logit.")
+                target_logit_index = -1 # Fallback to last logit
+
+            final_logits = logits[:, target_logit_index, :]
+            # Check for NaN/Inf
+            if torch.isnan(final_logits).any() or torch.isinf(final_logits).any():
+                print(f"WARNING: NaN or Inf detected in final_logits at index {target_logit_index}.")
+                # Handle appropriately - maybe return zeros or raise an error?
+                # For now, just print warning. Let the calling function handle it.
+
+            return final_logits, None # Return (B, V)
+        else:
+            # Training loss calculation (copied from notebook)
+            # P is already defined above
+            pad_ignore = torch.full((B, P), -1, dtype=targets.dtype, device=device)
+            full_targets = torch.cat([pad_ignore, targets], dim=1)
+            logits_lm = logits[:, :-1, :].contiguous()
+            targets_lm = full_targets[:, 1:].contiguous()
+            loss = F.cross_entropy(
+                logits_lm.view(-1, logits_lm.size(-1)),
+                targets_lm.view(-1),
+                ignore_index=-1
+            )
+            # Check for NaN/Inf in loss
+            if torch.isnan(loss) or torch.isinf(loss):
+                print("WARNING: NaN or Inf detected in loss calculation.")
+                # Potentially add debugging info here (e.g., print shapes, inputs)
+            return logits, loss
+
+    # --- Constrained generation method (from Section 2.9) ---
+    @torch.no_grad()
+    def generate_labels(self, idx, allowed_mask, max_new_tokens=24, temperature=0.0):
+        self.eval() # Ensure model is in eval mode
+        B = idx.size(0)
+        # Add soft prompt length to effective block size consideration
+        P = self.soft_prompt.size(1)
+        # Correct effective block size based on GPT class logic
+        effective_block_size = self.config.block_size # GPT forward handles cropping
+
+        # Start with input index
+        out = idx.clone() # Clone to avoid modifying original input
+
+        # Ensure allowed_mask is on the correct device
+        allowed_mask = allowed_mask.to(idx.device)
+
+        finished = torch.zeros(B, dtype=torch.bool, device=idx.device)
+
+        # Get global eos_id safely
+        global eos_id
+        current_eos_id = eos_id # Use the globally loaded eos_id
+
+        for step in range(max_new_tokens):
+            # Crop context if it exceeds model's block size (GPT forward handles this internally now)
+            # ctx = out if out.size(1) <= effective_block_size else out[:, -effective_block_size:]
+            ctx = out # Pass the current sequence
+
+            # Forward pass - expects shape (B, T), model handles soft prompt internally
+            # It returns logits for the *next* token prediction after the last token in ctx
+            logits, _ = self(ctx) # Gets logits for last token prediction, shape (B, V)
+
+            # Check for NaN/Inf in logits
+            if torch.isnan(logits).any() or torch.isinf(logits).any():
+                print(f"WARNING: NaN or Inf detected in logits during generation step {step}. Stopping generation.")
+                # Return what we have so far, excluding potentially bad last token
+                return out[:, idx.size(1):] # Or handle error differently
+
+            # Apply constraint mask
+            # Ensure mask shape matches logits shape
+            if logits.shape != allowed_mask.shape:
+                print(f"Warning: Logits shape {logits.shape} doesn't match mask shape {allowed_mask.shape}. Reshaping mask.")
+                # This assumes mask needs batch dim added
+                current_mask = allowed_mask.unsqueeze(0).expand_as(logits)
+            else:
+                current_mask = allowed_mask
+
+            logits = logits + current_mask
+
+            # Sample next token
+            if temperature <= 0:
+                # Greedy decoding
+                next_id = torch.argmax(logits, dim=-1) # (B,)
+            else:
+                # Temperature sampling
+                probs = F.softmax(logits / temperature, dim=-1)
+                # Check for NaN/Inf in probs
+                if torch.isnan(probs).any() or torch.isinf(probs).any():
+                    print(f"WARNING: NaN or Inf detected in probabilities during generation step {step}. Using argmax fallback.")
+                    next_id = torch.argmax(logits, dim=-1) # Fallback to greedy
+                else:
+                    try:
+                        next_id = torch.multinomial(probs, num_samples=1).squeeze(1) # (B,)
+                    except RuntimeError as e:
+                        print(f"WARNING: torch.multinomial failed: {e}. Using argmax fallback.")
+                        next_id = torch.argmax(logits, dim=-1) # Fallback to greedy
+
+
+            # Handle finished sequences (force EOS) and update output
+            # Check if current_eos_id is valid
+            if not isinstance(current_eos_id, int):
+                print(f"Warning: Global eos_id is not an integer ({current_eos_id}). Defaulting to 0.")
+                current_eos_id = 0
+            next_id = next_id.masked_fill(finished, current_eos_id) # Use the validated eos_id
+
+            # Check if next_id contains invalid values (e.g., negative)
+            if (next_id < 0).any():
+                print(f"WARNING: Negative token ID generated: {next_id}. Clipping to 0.")
+                next_id = torch.clamp(next_id, min=0)
+
+
+            # Append the next token ID
+            out = torch.cat([out, next_id.unsqueeze(1)], dim=1)
+
+            # Update finished status
+            finished |= (next_id == current_eos_id)
+
+            # Stop if all sequences in the batch are finished
+            if bool(finished.all()):
+                # print(f"Generation finished early at step {step+1}") # Optional debug info
+                break
+            # else:
+            # print(f"Generation reached max_new_tokens ({max_new_tokens})") # Optional debug info
+
+        # Return only the generated part (excluding the initial idx length)
+        return out[:, idx.size(1):]
+
+
+# --- Tokenizer Helper Functions ---
+# Added robustness and error checks
+# Global tokenizer maps and special IDs, loaded once at startup
+token2id, id2token = {}, {}
+eos_id = 0 # Default, will be overwritten
+pad_id = 0 # Default, will be overwritten
+detokenizer = None
+
+def load_tokenizer_data(dict_path):
+    global token2id, id2token, eos_id, pad_id, detokenizer
+    print(f"Loading vocabulary from {dict_path}...")
+    local_token2id, local_id2token = {}, {}
+    try:
+        with open(dict_path, encoding="utf-8") as f:
+            for i, line in enumerate(f):
+                parts = line.split() # Split by whitespace
+                if not parts: continue # Skip empty lines
+                tok = parts[0]
+                if tok in local_token2id:
+                    print(f"Warning: Duplicate token '{tok}' found at line {i+1}. Keeping first occurrence.")
+                    continue
+                local_token2id[tok] = i
+                local_id2token[i] = tok
+
+        # Assign to global variables only after successful loading
+        token2id = local_token2id
+        id2token = local_id2token
+
+        # Use a known special token ID if </s> is missing, otherwise default might be wrong
+        # Try multiple common EOS tokens
+        possible_eos = ["</s>", "<|endoftext|>", "[EOS]"]
+        found_eos = False
+        for eos_tok in possible_eos:
+            if eos_tok in token2id:
+                eos_id = token2id[eos_tok]
+                found_eos = True
+                print(f"Found EOS token '{eos_tok}' with ID: {eos_id}")
+                break
+        if not found_eos:
+            # If no common EOS found, fall back to the highest index or 0
+            eos_id = max(token2id.values()) if token2id else 0
+            print(f"Warning: Standard EOS tokens not found. Using highest index ({eos_id}) as EOS ID.")
+
+        # Assign pad_id, often same as eos_id or a specific <pad> token
+        pad_id = token2id.get("<pad>", eos_id) # Prefer <pad> if exists, else use eos_id
+        print(f"Using PAD ID: {pad_id}")
+
+        detokenizer = MosesDetokenizer(lang='en') # Initialize once
+        print(f"Vocabulary loaded. Size: {len(token2id)}")
+        if not detokenizer:
+            raise ValueError("MosesDetokenizer failed to initialize.")
+
+    except FileNotFoundError:
+        print(f"ERROR: Vocabulary file not found at {dict_path}")
+        raise
+    except Exception as e:
+        print(f"ERROR: Failed to load tokenizer data from {dict_path}: {e}")
+        raise
+
+def bpe_encode_lines(lines, shard_size=500, desc="BPE Encode"):
+    """ Encodes lines using external fastBPE binary. Added error checking. """
+    global BPE_CODES_PATH, FASTBPE_BIN_PATH
+    # --- Input Validation ---
+    if not isinstance(lines, list):
+        print(f"Warning: bpe_encode_lines expected a list, got {type(lines)}. Attempting conversion.")
+        try:
+            lines = list(lines)
+        except TypeError:
+            raise ValueError("Input 'lines' must be a list or convertible to a list.")
+
+    if not lines: return []
+
+    # --- Path and Executable Checks ---
+    abs_fastbpe_path = os.path.abspath(FASTBPE_BIN_PATH)
+    abs_bpe_codes_path = os.path.abspath(BPE_CODES_PATH)
+
+    if not os.path.exists(abs_fastbpe_path):
+        raise FileNotFoundError(f"fastBPE executable not found at {abs_fastbpe_path}")
+    if not os.path.exists(abs_bpe_codes_path):
+        raise FileNotFoundError(f"BPE codes file not found at {abs_bpe_codes_path}")
+    if not os.access(abs_fastbpe_path, os.X_OK):
+        print(f"Warning: fastBPE binary at {abs_fastbpe_path} is not executable. Attempting chmod...")
+        try:
+            os.chmod(abs_fastbpe_path, 0o755)
+        except OSError as e:
+            raise PermissionError(f"Failed to make fastBPE executable: {e}. Please check permissions.")
+
+    out_tokens = []
+    # Process in chunks
+    with tempfile.TemporaryDirectory() as td:
+        for start in range(0, len(lines), shard_size):
+            chunk = lines[start:start+shard_size]
+            src_path = os.path.join(td, f"src_{start}.txt")
+            dst_path = os.path.join(td, f"dst_{start}.bpe")
+
+            try:
+                # Write chunk to temp file, ensuring strings
+                with open(src_path, "w", encoding="utf-8") as f:
+                    for s in chunk:
+                        f.write(str(s or "").strip() + "\n") # Ensure string conversion
+
+                # Run fastBPE
+                cmd = [abs_fastbpe_path, "applybpe", dst_path, src_path, abs_bpe_codes_path]
+                # print(f"Running command: {' '.join(cmd)}") # Debug command
+                process = subprocess.run(
+                    cmd,
+                    capture_output=True, text=True, check=False # Don't check=True here, handle error below
+                )
+
+                # Check for errors specifically
+                if process.returncode != 0:
+                    # Log more details on failure
+                    print(f"ERROR: fastBPE failed (exit code {process.returncode}) on chunk starting at index {start}.")
+                    print(f"Command: {' '.join(cmd)}")
+                    print(f"Stderr:\n{process.stderr}")
+                    # Optionally print some input data
+                    print(f"First line of input chunk: {chunk[0] if chunk else 'N/A'}")
+                    raise subprocess.CalledProcessError(process.returncode, cmd, output=process.stdout, stderr=process.stderr)
+
+                # Read results if successful
+                with open(dst_path, "r", encoding="utf-8") as f:
+                    for line in f:
+                        out_tokens.append(line.strip().split())
+
+            except subprocess.CalledProcessError as e:
+                # Handle specific subprocess errors (already printed details)
+                raise # Re-raise to stop execution
+            except Exception as e:
+                print(f"ERROR: Unexpected error during BPE encoding chunk starting at index {start}: {e}")
+                traceback.print_exc() # Print full traceback for unexpected errors
+                raise # Re-raise
+    return out_tokens
+
+
+def tokens_to_ids(bpe_tokens):
+    """ Converts BPE token strings to IDs using the global map. Added checks. """
+    global token2id, pad_id
+    if not isinstance(bpe_tokens, list):
+        raise ValueError(f"Input 'bpe_tokens' must be a list, got {type(bpe_tokens)}.")
+
+    ids = []
+    oov_count = 0
+    for t in bpe_tokens:
+        if not isinstance(t, str):
+            print(f"Warning: Non-string token found in bpe_tokens: {t}. Using pad_id.")
+            ids.append(pad_id)
+            oov_count += 1
+            continue
+
+        id_val = token2id.get(t, pad_id)
+        ids.append(id_val)
+        if id_val == pad_id and t not in token2id:
+            oov_count += 1
+            # print(f"Warning: OOV token '{t}' mapped to pad_id {pad_id}") # Reduce noise
+    if oov_count > 0:
+        print(f"Info: Found {oov_count} OOV tokens in sequence of length {len(bpe_tokens)}.")
+    return ids, oov_count
+
+def ids_to_tokens(ids):
+    """ Converts IDs back to token strings. Added checks. """
+    global id2token
+    if not isinstance(ids, list):
+        raise ValueError(f"Input 'ids' must be a list, got {type(ids)}.")
+
+    tokens = []
+    for i in ids:
+        # Ensure ID is a valid integer before lookup
+        try:
+            # Handle potential floats or NaNs from generation
+            if isinstance(i, float) and math.isnan(i):
+                token = "<nan>"
+            else:
+                int_i = int(i)
+                token = id2token.get(int_i, "<unk>")
+        except (ValueError, TypeError):
+            print(f"Warning: Could not convert ID '{i}' to int. Using '<unk>'.")
+            token = "<unk>"
+        tokens.append(token)
+    return tokens
+
+
+def bpe_decode_tokens(bpe_tokens):
+    """ Converts BPE token strings back to readable text. Added checks. """
+    global detokenizer
+    if detokenizer is None:
+        raise RuntimeError("Detokenizer not initialized. Call load_tokenizer_data first.")
+    if not isinstance(bpe_tokens, list):
+        raise ValueError(f"Input 'bpe_tokens' must be a list, got {type(bpe_tokens)}.")
+
+    # Ensure all items are strings before joining
+    try:
+        str_tokens = [str(t) for t in bpe_tokens]
+    except Exception as e:
+        print(f"Error converting tokens to strings: {e}. Tokens: {bpe_tokens}")
+        return "<decoding error>"
+
+    s = ' '.join(str_tokens).replace('@@ ', '')
+    try:
+        # Detokenizer might fail on empty or unusual input
+        return detokenizer.detokenize(s.split()) if s.strip() else ""
+    except Exception as e:
+        print(f"Error during detokenization: {e}. Input string: '{s}'")
+        return "<detokenization error>"
+
+
+# --- Prediction Helper Functions ---
+
+def to_canonical(pred_chunk: str):
+    """ Maps a predicted text chunk to a canonical hallmark name. Added checks. """
+    global HALLMARKS
+    # Ensure input is a string
+    if not isinstance(pred_chunk, str):
+        # print(f"Warning: to_canonical received non-string input: {pred_chunk}. Returning None.")
+        return None
+
+    s = pred_chunk.strip().lower()
+    low = [L.lower() for L in HALLMARKS]
+    if not s: return None
+
+    if s in low:
+        return HALLMARKS[low.index(s)]
+
+    # Use difflib for fuzzy matching
+    try:
+        best = difflib.get_close_matches(s, low, n=1, cutoff=0.7)
+        return HALLMARKS[low.index(best[0])] if best else None
+    except Exception as e:
+        print(f"Error during difflib matching for '{s}': {e}")
+        return None # Return None on error
+
+def build_allowed_token_mask(vocab_size, device):
+    """ Builds the mask for constrained decoding. Added error checks. """
+    global HALLMARKS, token2id, eos_id, pad_id
+    allowed = set()
+
+    # --- Input Validation ---
+    if vocab_size <= 0:
+        raise ValueError("Vocabulary size must be positive.")
+    if not token2id:
+        raise RuntimeError("Tokenizer vocabulary (token2id) not loaded.")
+
+    print("Encoding hallmarks for mask...")
+    try:
+        # Ensure HALLMARKS is a list of strings
+        if not isinstance(HALLMARKS, list) or not all(isinstance(h, str) for h in HALLMARKS):
+            raise ValueError("HALLMARKS must be a list of strings.")
+        hallmark_bpes = bpe_encode_lines(HALLMARKS, desc="BPE Hallmarks (for mask)")
+        for bpe_list in hallmark_bpes:
+            ids, _ = tokens_to_ids(bpe_list)
+            allowed.update(ids)
+        print(f"Encoded {len(HALLMARKS)} hallmarks.")
+    except Exception as e:
+        print(f"ERROR: Failed to BPE encode or convert hallmarks for mask: {e}")
+        raise
+
+    print("Encoding separators for mask...")
+    SEPS = [", ", ",", "; ", ";", "|"]
+    try:
+        sep_bpes = bpe_encode_lines(SEPS, desc="BPE Separators (for mask)")
+        for bpe_list in sep_bpes:
+            ids, _ = tokens_to_ids(bpe_list)
+            allowed.update(ids)
+        print(f"Encoded {len(SEPS)} separators.")
+    except Exception as e:
+        print(f"ERROR: Failed to BPE encode or convert separators for mask: {e}")
+        raise
+
+    # Add EOS token - Check if eos_id is valid
+    if not isinstance(eos_id, int) or eos_id < 0 or eos_id >= vocab_size:
+        print(f"Warning: Invalid EOS ID ({eos_id}). Defaulting mask EOS to 0.")
+        effective_eos_id = 0
+    else:
+        effective_eos_id = eos_id
+    allowed.add(effective_eos_id)
+    print(f"Total allowed token IDs (including EOS {effective_eos_id}): {len(allowed)}")
+
+    # Create the mask tensor on CPU first
+    mask = torch.full((vocab_size,), float('-inf'), device=torch.device('cpu'))
+    try:
+        # Filter out potential invalid IDs before creating list for indexing
+        # Ensure pad_id is valid if used for filtering
+        effective_pad_id = pad_id if isinstance(pad_id, int) and 0 <= pad_id < vocab_size else -1 # Use -1 if pad_id is invalid
+
+        valid_allowed_ids = []
+        for id_ in allowed:
+            if isinstance(id_, int) and 0 <= id_ < vocab_size: # Check type and range
+                # Filter out pad_id unless it's the same as the effective_eos_id
+                if id_ != effective_pad_id or id_ == effective_eos_id:
+                    valid_allowed_ids.append(id_)
+            # else: print(f"Warning: Invalid ID {id_} in allowed set skipped.") # Reduce noise
+
+        if not valid_allowed_ids:
+            raise ValueError("No valid token IDs found to allow in the mask.")
+
+        # Check ranges again after filtering (belt and braces)
+        max_valid_id = max(valid_allowed_ids)
+        min_valid_id = min(valid_allowed_ids)
+        if max_valid_id >= vocab_size or min_valid_id < 0:
+            # This should ideally not happen if filtering worked
+            raise IndexError(f"Filtered allowed IDs still out of range [{min_valid_id}, {max_valid_id}] for vocab size {vocab_size}.")
+
+        # Apply mask
+        mask[valid_allowed_ids] = 0.0 # Use list directly
+        print(f"Mask created with {len(valid_allowed_ids)} allowed indices.")
+
+    except IndexError as e:
+        print(f"ERROR: Index error while creating mask. Vocab size: {vocab_size}. Error: {e}")
+        # Find problematic IDs more carefully
+        problem_ids = [i for i in allowed if not isinstance(i, int) or i < 0 or i >= vocab_size]
+        print(f"Problematic IDs in allowed set: {problem_ids}")
+        raise
+    except Exception as e:
+        print(f"ERROR: Unexpected error creating mask: {e}")
+        traceback.print_exc()
+        raise
+
+    # Move final mask to target device
+    try:
+        target_device = torch.device(device) # Ensure device is a torch.device object
+        return mask.to(target_device)
+    except Exception as e:
+        print(f"Error moving mask to device '{device}': {e}")
+        raise
+
+
+# --- Global Variables for Loaded Model and Assets ---
+inference_model = None
+ALLOWED_MASK = None
+model_device = "cpu"
+config = None # Added global config
+
+# --- Initialization Function ---
+def initialize_model_and_tokenizer():
+    global inference_model, ALLOWED_MASK, model_device, token2id, config # Add config
+
+    print("Initializing model...")
+    # Determine device
+    model_device = "cuda" if torch.cuda.is_available() else "cpu"
+    print(f"Using device: {model_device}")
+
+    # Load tokenizer data first (essential for vocab size)
+    try:
+        load_tokenizer_data(DICT_TXT_PATH)
+        if not token2id: # Check if loading actually populated the dict
+            raise ValueError("Tokenizer loading failed to populate token2id dictionary.")
+    except Exception as e:
+        print(f"FATAL: Could not load tokenizer data. Cannot proceed. Error: {e}")
+        return False # Indicate failure
+
+    # Define model config (MUST match finetuning config)
+    try:
+        # Ensure config is globally accessible after definition
+        config = GPTConfig(
+            vocab_size=len(token2id), # Get vocab size from loaded data
+            block_size=128, # Match training
+            n_layer=6, # Match training
+            n_head=6, # Match training
+            n_embd=384, # Match training
+            dropout=0.1, # Match training (dropout is off in eval mode)
+            bias=True # Match training
+        )
+        print(f"Model Config: {config}")
+    except Exception as e:
+        print(f"FATAL: Error creating GPTConfig: {e}")
+        return False
+
+    # Instantiate base and wrapped model (on CPU initially)
+    try:
+        base_gpt = GPT(config)
+        inference_model = GPTWithSoftPrompt(base_gpt, prompt_len=1)
+    except Exception as e:
+        print(f"FATAL: Error instantiating model: {e}")
+        traceback.print_exc()
+        return False
+
+    # Load finetuned weights
+    print(f"Loading finetuned weights from: {FINETUNED_MODEL_PATH}")
+    if not os.path.exists(FINETUNED_MODEL_PATH):
+        print(f"ERROR: Model weights file not found at {FINETUNED_MODEL_PATH}")
+        return False
+
+    try:
+        # Load state dict onto CPU first
+        state_dict = torch.load(FINETUNED_MODEL_PATH, map_location='cpu')
+
+        # Clean state dict keys (handle DDP 'module.' prefix)
+        cleaned_state_dict = {}
+        for k, v in state_dict.items():
+            name = k[7:] if k.startswith('module.') else k
+            cleaned_state_dict[name] = v
+
+        # Load into model
+        missing_keys, unexpected_keys = inference_model.load_state_dict(cleaned_state_dict, strict=False)
+        if missing_keys:
+            # Filter out non-persistent buffer keys if necessary (though strict=False should handle this)
+            missing_persistent = [k for k in missing_keys if inference_model.get_parameter(k) is not None or inference_model.get_buffer(k) is not None]
+            if missing_persistent:
+                print("Warning: Missing persistent keys during state dict load:", missing_persistent)
+        if unexpected_keys:
+            print("Warning: Unexpected keys during state dict load:", unexpected_keys)
+        print("Weights loaded successfully.")
+
+    except Exception as e:
+        print(f"Error loading state dict from {FINETUNED_MODEL_PATH}: {e}")
+        print("Ensure the model architecture matches the saved checkpoint and the file is not corrupted.")
+        traceback.print_exc()
+        return False
+
+    # Move model to target device and set to eval mode
+    try:
+        inference_model.to(model_device)
+        inference_model.eval()
+        print(f"Model moved to device: {model_device} and set to eval mode.")
+    except Exception as e:
+        print(f"Error moving model to device '{model_device}': {e}")
+        traceback.print_exc()
+        return False
+
+
+    # Build the allowed token mask (after model is on device)
+    print("Building allowed token mask...")
+    try:
+        if config.vocab_size <= 0:
+            raise ValueError("Vocabulary size must be positive to build mask.")
+        # Ensure model_device is valid before passing
+        device_obj = torch.device(model_device)
+        ALLOWED_MASK = build_allowed_token_mask(config.vocab_size, device_obj)
+        print("Allowed token mask created.")
+    except Exception as e:
+        print(f"ERROR: Failed to build allowed token mask: {e}")
+        traceback.print_exc()
+        return False
+
+    return True # Indicate success
+
+
+# --- Inference Function ---
+def predict_hallmarks(abstract_text):
+    global inference_model, ALLOWED_MASK, model_device, token2id, eos_id
+
+    # --- Pre-computation Checks ---
+    if inference_model is None:
+        print("Error: Inference model is not loaded.")
+        return ["Error: Model not loaded"]
+    if ALLOWED_MASK is None:
+        print("Error: Allowed mask is not built.")
+        return ["Error: Mask not built"]
+    if not token2id:
+        print("Error: Tokenizer vocabulary not loaded.")
+        return ["Error: Tokenizer not loaded"]
+
+    # --- Input Validation ---
+    if not isinstance(abstract_text, str):
+        print(f"Warning: Received non-string abstract text type: {type(abstract_text)}. Attempting conversion.")
+        try:
+            abstract_text = str(abstract_text)
+        except Exception:
+            return ["Error: Invalid input type"]
+    if not abstract_text.strip():
+        print("Warning: Received empty or whitespace-only abstract text.")
+        return [] # Return empty list for empty input
+
+
+    try:
+        # --- 1. Preprocess and Tokenize Input ---
+        print("Tokenizing input abstract...")
+        cleaned_abstract = " ".join(abstract_text.split())
+        if not cleaned_abstract:
+            print("Warning: Input abstract contains only whitespace after cleaning.")
+            return []
+
+        bpe_tokens_list = bpe_encode_lines([cleaned_abstract])
+        if not bpe_tokens_list or not bpe_tokens_list[0]: # Check if list or first element is empty
+            print("Warning: BPE encoding resulted in empty tokens.")
+            return []
+        bpe_tokens = bpe_tokens_list[0]
+
+        input_ids_list, oov = tokens_to_ids(bpe_tokens)
+        if oov > 0:
+            print(f"Info: Input contained {oov} OOV tokens.")
+
+        # Add EOS token
+        input_ids = input_ids_list + [eos_id]
+
+        # Convert to tensor and move to device
+        input_tensor = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0).to(model_device)
+
+        # --- 2. Generate Predictions ---
+        print("Generating predictions...")
+        with torch.no_grad():
+            generated_ids_tensor = inference_model.generate_labels(
+                input_tensor,
+                allowed_mask=ALLOWED_MASK,
+                max_new_tokens=30,
+                temperature=0.0
+            )
+
+        # --- 3. Decode and Post-process ---
+        print("Decoding and cleaning predictions...")
+        if generated_ids_tensor is None or generated_ids_tensor.numel() == 0:
+            print("Warning: Generation resulted in empty tensor.")
+            generated_ids = []
+        else:
+            # Ensure tensor is on CPU before converting to list
+            generated_ids = generated_ids_tensor[0].cpu().tolist()
+
+        if not generated_ids:
+            print("No tokens generated.")
+            return []
+
+        generated_tokens = ids_to_tokens(generated_ids)
+
+        # Remove tokens after EOS if present
+        try:
+            eos_token_str = id2token.get(eos_id, "</s>") # Get string representation
+            if eos_token_str in generated_tokens:
+                eos_index = generated_tokens.index(eos_token_str)
+                generated_tokens = generated_tokens[:eos_index]
+        except ValueError:
+            pass # EOS not found is okay
+
+        # Decode BPE tokens to string
+        generated_text = bpe_decode_tokens(generated_tokens).strip().lower()
+        print(f"Raw generated text: '{generated_text}'")
+
+        # Split potential multi-labels and map to canonical
+        parts = []
+        if generated_text:
+            potential_parts = re.split(r'[;,|]\s*', generated_text)
+            parts = [p.strip() for p in potential_parts if p.strip()]
+            if not parts: # Handle case with no delimiters
+                parts = [generated_text]
+
+        predicted_labels = []
+        seen_labels = set()
+        for p in parts:
+            canonical_label = to_canonical(p)
+            if canonical_label and canonical_label not in seen_labels:
+                predicted_labels.append(canonical_label)
+                seen_labels.add(canonical_label)
+
+        print(f"Final predicted labels: {predicted_labels}")
+        return predicted_labels
+
+    # --- Error Handling ---
+    except FileNotFoundError as fnf_err:
+        print(f"ERROR during prediction (File Not Found - likely BPE related): {fnf_err}")
+        traceback.print_exc()
+        return ["Error: BPE file processing error"]
+    except PermissionError as perm_err:
+        print(f"ERROR during prediction (Permission Error - likely fastBPE): {perm_err}")
+        traceback.print_exc()
+        return ["Error: BPE execution permission"]
+    except RuntimeError as run_err:
+        if "CUDA out of memory" in str(run_err):
+            print(f"ERROR: CUDA Out of Memory during prediction. Input length: {len(input_ids) if 'input_ids' in locals() else 'N/A'}")
+            traceback.print_exc()
+            return ["Error: Input too long (OOM)"]
+        else:
+            print(f"ERROR during prediction (PyTorch RuntimeError): {run_err}")
+            traceback.print_exc()
+            return ["Error: Model runtime error"]
+    except Exception as e:
+        print(f"ERROR during prediction (General Exception): {e}")
+        traceback.print_exc()
+        return [f"Error: An unexpected error occurred"]
+
+
| 946 |
+
# --- Flask App ---
|
| 947 |
+
app = Flask(__name__)
|
| 948 |
+
|
| 949 |
+
# --- Load Model on Startup ---
|
| 950 |
+
model_initialized = False
|
| 951 |
+
|
| 952 |
+
@app.before_request
|
| 953 |
+
def ensure_model_loaded():
|
| 954 |
+
""" Ensures model is loaded before handling the first request. """
|
| 955 |
+
global model_initialized
|
| 956 |
+
if not model_initialized:
|
| 957 |
+
print("First request received, attempting to initialize model...")
|
| 958 |
+
# Add basic locking if deploying with multiple workers (though not fully thread-safe here)
|
| 959 |
+
# For true multi-worker safety, model loading should happen before workers fork.
|
| 960 |
+
try:
|
| 961 |
+
if initialize_model_and_tokenizer():
|
| 962 |
+
model_initialized = True
|
| 963 |
+
print("Model initialization successful.")
|
| 964 |
+
else:
|
| 965 |
+
print("FATAL: Model initialization failed during first request.")
|
| 966 |
+
# We won't raise an error here, but subsequent requests will fail until fixed.
|
| 967 |
+
except Exception as init_err:
|
| 968 |
+
print(f"FATAL: Exception during model initialization: {init_err}")
|
| 969 |
+
traceback.print_exc()
|
| 970 |
+
|
| 971 |
+
|
| 972 |
+
# --- Routes ---
|
| 973 |
+
@app.route('/')
|
| 974 |
+
def home():
|
| 975 |
+
""" Renders the HTML frontend page. """
|
| 976 |
+
# Check if initialization failed and show an error page if so?
|
| 977 |
+
# For simplicity, we assume initialization works or subsequent predict calls fail.
|
| 978 |
+
return render_template('index.html')
|
| 979 |
+
|
| 980 |
+
@app.route('/predict', methods=['POST'])
|
| 981 |
+
def predict():
|
| 982 |
+
""" Handles prediction requests from the frontend. """
|
| 983 |
+
global model_initialized
|
| 984 |
+
# Check if model is ready
|
| 985 |
+
if not model_initialized:
|
| 986 |
+
print("Error: Model not initialized when /predict called.")
|
| 987 |
+
# Return a specific status code like Service Unavailable
|
| 988 |
+
return jsonify({'error': 'Model is not ready. Please try again later.'}), 503
|
| 989 |
+
|
| 990 |
+
# Validate request format
|
| 991 |
+
if not request.is_json:
|
| 992 |
+
return jsonify({'error': 'Request must be JSON'}), 400
|
| 993 |
+
|
| 994 |
+
data = request.get_json()
|
| 995 |
+
abstract = data.get('abstract')
|
| 996 |
+
|
| 997 |
+
# Validate input abstract
|
| 998 |
+
if not abstract:
|
| 999 |
+
return jsonify({'error': 'Missing "abstract" field in JSON request'}), 400
|
| 1000 |
+
if not isinstance(abstract, str):
|
| 1001 |
+
return jsonify({'error': '"abstract" field must be a string'}), 400
|
| 1002 |
+
if len(abstract.strip()) == 0:
|
| 1003 |
+
print("Received empty abstract, returning empty prediction.")
|
| 1004 |
+
return jsonify({'predictions': []})
|
| 1005 |
+
MAX_ABSTRACT_LEN = 10000 # Define max length
|
| 1006 |
+
if len(abstract) > MAX_ABSTRACT_LEN:
|
| 1007 |
+
print(f"Received overly long abstract ({len(abstract)} chars), rejecting.")
|
| 1008 |
+
return jsonify({'error': f'Input abstract is too long (max {MAX_ABSTRACT_LEN} chars)'}), 413 # Payload Too Large
|
| 1009 |
+
|
| 1010 |
+
print(f"\n--- Received Prediction Request ---")
|
| 1011 |
+
print(f"Input Abstract (first 100 chars): {abstract[:100]}...")
|
| 1012 |
+
|
| 1013 |
+
try:
|
| 1014 |
+
# Perform prediction
|
| 1015 |
+
predictions = predict_hallmarks(abstract)
|
| 1016 |
+
print(f"--- Prediction Complete ---")
|
| 1017 |
+
|
| 1018 |
+
# Check if the result indicates an internal error occurred
|
| 1019 |
+
if isinstance(predictions, list) and len(predictions) > 0 and predictions[0].startswith("Error:"):
|
| 1020 |
+
print(f"Internal error during prediction: {predictions[0]}")
|
| 1021 |
+
# Return a generic server error to the client
|
| 1022 |
+
return jsonify({'error': 'An internal error occurred during prediction.'}), 500
|
| 1023 |
+
else:
|
| 1024 |
+
# Return successful predictions
|
| 1025 |
+
return jsonify({'predictions': predictions})
|
| 1026 |
+
|
| 1027 |
+
except Exception as e:
|
| 1028 |
+
# Catch unexpected errors in the route handler itself
|
| 1029 |
+
print(f"--- Prediction Failed Unexpectedly in Route ---")
|
| 1030 |
+
print(f"Error: {e}")
|
| 1031 |
+
traceback.print_exc()
|
| 1032 |
+
return jsonify({'error': 'An internal server error occurred.'}), 500
|
| 1033 |
+
|
| 1034 |
+
# --- Run the App ---
|
| 1035 |
+
if __name__ == '__main__':
|
| 1036 |
+
# Initialize model eagerly when running script directly
|
| 1037 |
+
if not model_initialized:
|
| 1038 |
+
print("Running script directly, initializing model eagerly...")
|
| 1039 |
+
if initialize_model_and_tokenizer():
|
| 1040 |
+
model_initialized = True
|
| 1041 |
+
print("Model initialization successful.")
|
| 1042 |
+
else:
|
| 1043 |
+
print("FATAL: Model initialization failed. Cannot start Flask server.")
|
| 1044 |
+
exit(1) # Exit if model fails to load on startup
|
| 1045 |
+
|
| 1046 |
+
# Check fastBPE path validity before starting server
|
| 1047 |
+
abs_fastbpe_path = os.path.abspath(FASTBPE_BIN_PATH)
|
| 1048 |
+
if not os.path.exists(abs_fastbpe_path):
|
| 1049 |
+
print(f"ERROR: fastBPE binary not found at '{abs_fastbpe_path}'.")
|
| 1050 |
+
print("Please ensure fastBPE is compiled and the path is correct relative to app.py.")
|
| 1051 |
+
exit(1)
|
| 1052 |
+
if not os.access(abs_fastbpe_path, os.X_OK):
|
| 1053 |
+
print(f"ERROR: fastBPE binary at '{abs_fastbpe_path}' is not executable.")
|
| 1054 |
+
print("Attempting to make it executable with 'chmod +x'...")
|
| 1055 |
+
try:
|
| 1056 |
+
os.chmod(abs_fastbpe_path, 0o755)
|
| 1057 |
+
print(f"Successfully made '{abs_fastbpe_path}' executable.")
|
| 1058 |
+
except OSError as e:
|
| 1059 |
+
print(f"ERROR: Failed to make fastBPE executable: {e}")
|
| 1060 |
+
print("Please set execute permissions manually (e.g., 'chmod +x ./fast').")
|
| 1061 |
+
exit(1)
|
| 1062 |
+
|
| 1063 |
+
print("Starting Flask server...")
|
| 1064 |
+
# Use host='0.0.0.0' to make it accessible on your network
|
| 1065 |
+
# Set debug=False for production environments
|
| 1066 |
+
app.run(host='0.0.0.0', port=5000, debug=False) # Changed debug to False
|
| 1067 |
+
|
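For reference, the /predict route above expects a JSON body with an "abstract" string and responds with a JSON object containing a "predictions" list. Below is a minimal client sketch, not part of the committed files: it assumes the server is reachable at http://localhost:5000 (the port configured in app.run and the Dockerfile), uses only the standard library, and the abstract text and printed result are purely illustrative.

# Hypothetical client call for the /predict endpoint (illustrative sketch, not part of this Space).
import json
import urllib.request

payload = json.dumps({"abstract": "Example abstract text about tumor angiogenesis ..."}).encode("utf-8")
req = urllib.request.Request(
    "http://localhost:5000/predict",          # assumes a local deployment on port 5000
    data=payload,
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urllib.request.urlopen(req) as resp:
    result = json.load(resp)                  # e.g. {"predictions": ["inducing angiogenesis"]}
    print(result.get("predictions", []))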
bpecodes
ADDED
The diff for this file is too large to render. See raw diff
dict.txt
ADDED
The diff for this file is too large to render. See raw diff
fast
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c86a282547b0bc4d088adee84b45ad031db1b23139d944a159ae8f9cfce3186e
size 132336
hoc_best.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:13d02e9a35bd2fc935f5f1291a16b568928e223cba9152f5284ef45755f91491
size 107913359
requirements.txt
ADDED
@@ -0,0 +1,7 @@
Flask
torch
sacremoses==0.0.53
numpy
pandas
tqdm
waitress
templates/index.html
ADDED
@@ -0,0 +1,191 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>BioGPT HoC Inference</title>
    <!-- Load Tailwind CSS -->
    <script src="https://cdn.tailwindcss.com"></script>
    <!-- Include Inter font -->
    <link rel="preconnect" href="https://fonts.googleapis.com">
    <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
    <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap" rel="stylesheet">
    <style>
        body {
            font-family: 'Inter', sans-serif;
        }
        .hallmark-badge {
            display: inline-block;
            padding: 0.3rem 0.6rem;
            margin: 0.2rem;
            border-radius: 0.5rem; /* Rounded corners */
            font-size: 0.8rem;
            line-height: 1rem;
            font-weight: 500;
            cursor: default;
            transition: background-color 0.2s ease-in-out;
            border: 1px solid transparent;
        }
        /* Color definitions (same as before) */
        .hallmark-color-0 { background-color: #EFF6FF; color: #1E40AF; border-color: #BEE3F8; } /* Blue */
        .hallmark-color-1 { background-color: #ECFDF5; color: #065F46; border-color: #A7F3D0; } /* Green */
        .hallmark-color-2 { background-color: #FFFBEB; color: #92400E; border-color: #FDE68A; } /* Yellow */
        .hallmark-color-3 { background-color: #FEF2F2; color: #991B1B; border-color: #FECACA; } /* Red */
        .hallmark-color-4 { background-color: #F5F3FF; color: #5B21B6; border-color: #DDD6FE; } /* Purple */
        .hallmark-color-5 { background-color: #FDF2F8; color: #9D174D; border-color: #FBCFE8; } /* Pink */
        .hallmark-color-6 { background-color: #EEF2FF; color: #3730A3; border-color: #C7D2FE; } /* Indigo */
        .hallmark-color-7 { background-color: #F0FDFA; color: #134E4A; border-color: #99F6E4; } /* Teal */
        .hallmark-color-8 { background-color: #FFF7ED; color: #9A3412; border-color: #FED7AA; } /* Orange */
        .hallmark-color-9 { background-color: #F9FAFB; color: #374151; border-color: #E5E7EB; } /* Gray */

        /* Loading Spinner */
        .spinner {
            border: 4px solid rgba(0, 0, 0, 0.1);
            width: 36px;
            height: 36px;
            border-radius: 50%;
            border-left-color: #4f46e5; /* Indigo */
            animation: spin 1s ease infinite;
            margin: 20px auto;
        }
        @keyframes spin {
            0% { transform: rotate(0deg); }
            100% { transform: rotate(360deg); }
        }
    </style>
</head>
<body class="bg-gray-100 min-h-screen flex flex-col items-center justify-center p-4 sm:p-6">

    <div class="bg-white p-6 sm:p-8 rounded-lg shadow-xl w-full max-w-2xl space-y-6">

        <!-- Header -->
        <div class="text-center border-b pb-4">
            <h1 class="text-2xl sm:text-3xl font-bold text-gray-800 mb-1">BioGPT HoC Inference</h1>
            <h2 class="text-md sm:text-lg font-semibold text-blue-700 mb-2">~27M Parameter Finetuned SLM</h2>
            <p class="text-sm text-gray-600 max-w-2xl mx-auto">
                Enter a biomedical abstract below to predict Hallmarks of Cancer labels using the finetuned model.
            </p>
        </div>

        <!-- Input Area -->
        <div>
            <label for="abstract-input" class="block text-sm font-medium text-gray-700 mb-1">Enter Abstract Text:</label>
            <textarea id="abstract-input" rows="8" class="w-full p-3 border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500 focus:border-transparent transition duration-150 ease-in-out shadow-sm" placeholder="Paste or type abstract here..."></textarea>
        </div>

        <!-- Predict Button -->
        <div class="text-center">
            <button id="predict-button" class="bg-blue-600 hover:bg-blue-700 text-white font-semibold py-2.5 px-8 rounded-lg transition duration-150 ease-in-out shadow-md hover:shadow-lg focus:outline-none focus:ring-2 focus:ring-blue-500 focus:ring-offset-2 disabled:opacity-50 disabled:cursor-not-allowed">
                Predict Hallmarks
            </button>
        </div>

        <!-- Results Area -->
        <div id="results-area" class="space-y-3 hidden">
            <div class="border border-green-200 p-4 rounded-lg bg-green-50 shadow-sm min-h-[80px]">
                <h3 class="text-md font-semibold text-green-800 mb-2">Predicted Hallmarks:</h3>
                <div id="loading-indicator" class="hidden text-center">
                    <div class="spinner"></div>
                    <p class="text-sm text-gray-600">Predicting...</p>
                </div>
                <div id="error-message" class="hidden text-center text-red-600 font-medium">
                    An error occurred. Please try again.
                </div>
                <div id="prediction-output">
                    <!-- Badges will be inserted here -->
                </div>
            </div>
        </div>

    </div>

    <script>
        // --- Hallmark Colors (for consistent styling) ---
        const hallmarkList = [ // Ensure this matches the list in app.py
            "activating invasion and metastasis", "avoiding immune destruction",
            "cellular energetics", "enabling replicative immortality",
            "evading growth suppressors", "genomic instability and mutation",
            "inducing angiogenesis", "resisting cell death",
            "sustaining proliferative signaling", "tumor promoting inflammation",
        ];
        const hallmarkColors = {};
        hallmarkList.forEach((hallmark, index) => {
            hallmarkColors[hallmark] = `hallmark-color-${index % 10}`;
        });

        // --- DOM Elements ---
        const abstractInput = document.getElementById('abstract-input');
        const predictButton = document.getElementById('predict-button');
        const resultsArea = document.getElementById('results-area');
        const loadingIndicator = document.getElementById('loading-indicator');
        const errorMessage = document.getElementById('error-message');
        const predictionOutput = document.getElementById('prediction-output');

        // --- Event Listener ---
        predictButton.addEventListener('click', handlePrediction);

        // --- Prediction Logic ---
        async function handlePrediction() {
            const abstractText = abstractInput.value.trim();
            if (!abstractText) {
                alert("Please enter some abstract text.");
                return;
            }

            // --- UI Updates: Start loading ---
            predictButton.disabled = true;
            resultsArea.classList.remove('hidden');
            predictionOutput.innerHTML = ''; // Clear previous results
            errorMessage.classList.add('hidden');
            loadingIndicator.classList.remove('hidden');

            try {
                // --- Call the backend API ---
                const response = await fetch('/predict', {
                    method: 'POST',
                    headers: {
                        'Content-Type': 'application/json',
                    },
                    body: JSON.stringify({ abstract: abstractText }),
                });

                // --- Handle Response ---
                loadingIndicator.classList.add('hidden'); // Hide loading indicator
                if (!response.ok) {
                    const errorData = await response.json();
                    throw new Error(errorData.error || `HTTP error! status: ${response.status}`);
                }

                const data = await response.json();
                displayPredictions(data.predictions);

            } catch (error) {
                console.error('Prediction failed:', error);
                loadingIndicator.classList.add('hidden');
                errorMessage.textContent = `Prediction failed: ${error.message}`;
                errorMessage.classList.remove('hidden');
            } finally {
                predictButton.disabled = false; // Re-enable button
            }
        }

        // --- Display Logic ---
        function displayPredictions(predictions) {
            predictionOutput.innerHTML = ''; // Clear previous just in case
            errorMessage.classList.add('hidden');

            if (predictions && predictions.length > 0) {
                predictions.forEach(label => {
                    const badge = document.createElement('span');
                    badge.textContent = label;
                    const colorClass = hallmarkColors[label] || 'hallmark-color-9'; // Default color
                    badge.className = `hallmark-badge ${colorClass}`;
                    predictionOutput.appendChild(badge);
                });
            } else {
                predictionOutput.innerHTML = '<span class="text-gray-500 italic text-sm">No specific hallmarks predicted for this abstract.</span>';
            }
        }
    </script>
</body>
</html>