Priyansu19 committed on
Commit
b347ca3
·
1 Parent(s): 7e1bc94

Add all BioGPT app files with LFS tracking

Browse files
Files changed (8) hide show
  1. Dockerfile +36 -0
  2. app.py +1067 -0
  3. bpecodes +0 -0
  4. dict.txt +0 -0
  5. fast +3 -0
  6. hoc_best.pt +3 -0
  7. requirements.txt +7 -0
  8. templates/index.html +191 -0
Dockerfile ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Use an official Python runtime as the parent image (e.g., 3.11 or 3.12).
FROM python:3.11-slim

# Set the working directory in the container.
WORKDIR /app

# Install system dependencies if needed (e.g., build-essential if compiling fastBPE).
# RUN apt-get update && apt-get install -y --no-install-recommends build-essential && rm -rf /var/lib/apt/lists/*

# Copy the requirements file first so the dependency layer is cached until
# requirements.txt itself changes.
COPY requirements.txt requirements.txt

# Install Python dependencies.
RUN pip install --no-cache-dir -r requirements.txt

# Copy the application code and the assets it loads at startup.
# IMPORTANT: This assumes 'fast' is pre-compiled for Linux x86_64.
COPY app.py .
COPY hoc_best.pt .
COPY bpecodes .
COPY dict.txt .
# Copy the 'fast' executable directly. This comment must stay on its own line:
# a '#' that is not at the start of a line is passed to the instruction as an
# argument, which would turn the comment words into extra COPY sources.
COPY fast .
COPY templates/ ./templates/

# Ensure the 'fast' binary is executable within the container.
RUN chmod +x ./fast

# Make port 5000 available (the port waitress listens on below).
EXPOSE 5000

# Disable Python output buffering so log lines appear immediately.
ENV PYTHONUNBUFFERED=1

# Run the app using waitress (a production-ready WSGI server).
# Assumes the Flask app instance is named 'app' in 'app.py'.
CMD ["waitress-serve", "--host=0.0.0.0", "--port=5000", "app:app"]
app.py ADDED
@@ -0,0 +1,1067 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ import math
4
+ import difflib
5
+ import tempfile
6
+ import numpy as np
7
+ import pandas as pd
8
+ from tqdm.auto import tqdm
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from dataclasses import dataclass
13
+ from sacremoses import MosesDetokenizer
14
+ from flask import Flask, request, jsonify, render_template
15
+ import traceback # Import traceback for better error logging
16
+ import re # Import the regular expression module
17
+
18
# --- Constants and Paths ---
# Every path is relative, so it resolves against the working directory of the
# process; keep these files next to app.py or provide correct paths.
FINETUNED_MODEL_PATH = "hoc_best.pt"
BPE_CODES_PATH = "bpecodes"
DICT_TXT_PATH = "dict.txt"
FASTBPE_BIN_PATH = "./fast"  # Assumes the fast executable is alongside app.py

# Canonical hallmark names. Keep this consistent with training/evaluation.
HALLMARKS = [
    "activating invasion and metastasis",
    "avoiding immune destruction",
    "cellular energetics",
    "enabling replicative immortality",
    "evading growth suppressors",
    "genomic instability and mutation",
    "inducing angiogenesis",
    "resisting cell death",
    "sustaining proliferative signaling",
    "tumor promoting inflammation",
]
32
+
33
+ # --- Model Architecture Definitions (Copy from your notebook) ---
34
+ # NOTE: Make sure these classes are IDENTICAL to the ones used for training
35
+ # including GPTConfig, LayerNorm, CausalSelfAttention, MLP, Block, GPT, GPTWithSoftPrompt
36
+
37
@dataclass
class GPTConfig:
    """Hyperparameters consumed by the model classes defined in this file."""

    # Number of rows in the position-embedding table (longest usable sequence).
    block_size: int
    # Number of rows in the token-embedding table and width of the LM head.
    vocab_size: int
    # How many transformer blocks are stacked.
    n_layer: int
    # Attention heads per block; n_embd must be divisible by it.
    n_head: int
    # Embedding (hidden) dimension.
    n_embd: int
    # Dropout probability passed to every nn.Dropout.
    dropout: float = 0.0
    # Whether the Linear and LayerNorm layers carry a bias term.
    bias: bool = True
46
+
47
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with an optional bias."""

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        # With bias disabled the attribute still exists, but is None.
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, x):
        normalized_shape = self.weight.shape
        return F.layer_norm(x, normalized_shape, self.weight, self.bias, 1e-5)
55
+
56
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention in which a position only sees earlier positions."""

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # A single projection yields query, key and value for all heads.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # Output projection applied after the heads are merged again.
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # Use the fused kernel when this PyTorch build provides it.
        self.flash = hasattr(F, 'scaled_dot_product_attention')
        if not self.flash:
            # Lower-triangular causal mask, kept as a persistent buffer so it
            # moves with the module and is part of its state_dict.
            size = config.block_size
            causal = torch.tril(torch.ones(size, size)).view(1, 1, size, size)
            self.register_buffer("bias", causal, persistent=True)

    def forward(self, x):
        B, T, C = x.size()  # batch size, sequence length, embedding dim (n_embd)
        head_dim = C // self.n_head

        def split_heads(t):
            # (B, T, C) -> (B, nh, T, hs)
            return t.view(B, T, self.n_head, head_dim).transpose(1, 2)

        q, k, v = (split_heads(part) for part in self.c_attn(x).split(self.n_embd, dim=2))

        if self.flash:
            # Dropout inside the fused kernel is only active while training.
            drop_p = self.attn_dropout.p if self.training else 0.0
            y = F.scaled_dot_product_attention(
                q, k, v, attn_mask=None, dropout_p=drop_p, is_causal=True
            )
        else:
            # Manual attention: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            if hasattr(self, 'bias'):
                blocked = self.bias[:, :, :T, :T] == 0
            else:
                # Fallback in case the buffer was never registered.
                blocked = torch.tril(torch.ones(T, T, device=x.device)).view(1, 1, T, T) == 0
            att = att.masked_fill(blocked, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)

        # Put the heads back side by side, then apply the output projection.
        merged = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_dropout(self.c_proj(merged))
108
+
109
+
110
class MLP(nn.Module):
    """Position-wise feed-forward network: expand 4x, GELU, project back, dropout."""

    def __init__(self, config):
        super().__init__()
        hidden_dim = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden_dim, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        expanded = self.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(expanded))
124
+
125
class Block(nn.Module):
    """One transformer layer: attention and MLP, each normalized first and added residually."""

    def __init__(self, config):
        super().__init__()
        self.ln1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        # Normalize before each sub-layer; the un-normalized input is carried
        # through by the residual additions.
        after_attention = x + self.attn(self.ln1(x))
        return after_attention + self.mlp(self.ln2(after_attention))
137
+
138
class GPT(nn.Module):
    """Decoder-only transformer language model.

    Token and position embeddings are summed, passed through ``n_layer``
    blocks and a final LayerNorm, then projected to vocabulary logits by
    ``lm_head``, whose weight is shared with the token embedding.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            wpe=nn.Embedding(config.block_size, config.n_embd),
            drop=nn.Dropout(config.dropout),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying: https://paperswithcode.com/method/weight-tying
        self.transformer.wte.weight = self.lm_head.weight

        # Init all weights, then apply the special scaled init to the residual
        # projections, per the GPT-2 paper.
        self.apply(self._init_weights)
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights from N(0, 0.02) and zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Run the model on token ids ``idx`` of shape (b, t).

        Inputs longer than ``block_size`` are cropped to their last
        ``block_size`` tokens (a warning is printed); ``targets``, when given,
        is cropped to the same window so it stays aligned with the logits.

        Returns:
            ``(logits, loss)``. With ``targets`` the logits cover every
            position and ``loss`` is the cross-entropy (targets equal to -1
            are ignored). Without ``targets`` the logits cover only the last
            position, shape (b, 1, vocab_size), and ``loss`` is ``None``.
        """
        device = idx.device
        b, t = idx.size()
        if t > self.config.block_size:
            # Crop sequence if longer than block size
            print(f"Warning: Input sequence length ({t}) > block size ({self.config.block_size}). Cropping.")
            idx = idx[:, -self.config.block_size:]
            if targets is not None:
                # Without this the loss below would compare block_size logits
                # against t targets and fail on the shape mismatch.
                targets = targets[:, -self.config.block_size:]
            t = self.config.block_size
        pos = torch.arange(0, t, dtype=torch.long, device=device)  # shape (t)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if targets is not None:
            # Desired targets were given, so also calculate the loss.
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
            return logits, loss

        # Inference-time mini-optimization: only forward the lm_head on the
        # very last position (the list index [-1] preserves the time dim).
        logits = self.lm_head(x[:, [-1], :])
        if torch.isnan(logits).any() or torch.isinf(logits).any():
            # Only reported here; the caller decides how to handle bad logits.
            print("WARNING: NaN or Inf detected in logits during inference.")
        return logits, None
204
+
205
class GPTWithSoftPrompt(nn.Module):
    """A GPT whose input is prefixed with a learned "soft prompt".

    The wrapper reuses the ``transformer`` and ``lm_head`` modules of
    ``base_gpt`` (the same objects, not copies) and adds one parameter,
    ``soft_prompt`` of shape (1, prompt_len, n_embd), which is prepended to the
    embedded tokens on every forward pass.
    """

    def __init__(self, base_gpt: "GPT", prompt_len=1):
        super().__init__()
        self.config = base_gpt.config
        self.transformer = base_gpt.transformer
        self.lm_head = base_gpt.lm_head
        C = self.config.n_embd
        self.soft_prompt = nn.Parameter(torch.zeros(1, prompt_len, C))
        nn.init.normal_(self.soft_prompt, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Run the model on token ids ``idx`` of shape (B, T).

        ``T`` must not exceed ``config.block_size`` because positions index the
        position-embedding table directly; this method does not crop.

        Returns:
            Without ``targets``: ``(logits, None)`` where ``logits`` has shape
            (B, V) and belongs to the last position, i.e. the prediction that
            follows the final input token.
            With ``targets`` (shape (B, T)): ``(logits, loss)`` where
            ``logits`` has shape (B, P+T, V) and ``loss`` is the shifted
            cross-entropy; soft-prompt positions and targets equal to -1 are
            ignored.
        """
        B, T = idx.shape
        device = idx.device

        # token + position embeddings
        tok_emb = self.transformer.wte(idx)  # (B,T,C)
        pos = torch.arange(0, T, dtype=torch.long, device=device)
        pos_emb = self.transformer.wpe(pos)  # (T,C)
        x_tokens = tok_emb + pos_emb

        # Prepend the soft prompt, on the same device as the input.
        soft = self.soft_prompt.to(device).expand(B, -1, -1)  # (B,P,C)
        P = soft.size(1)
        x = torch.cat([soft, x_tokens], dim=1)  # (B,P+T,C)

        # Standard transformer forward pass.
        x = self.transformer.drop(x)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)  # (B,P+T,V)

        if targets is None:
            # The prediction after the last input token sits at index P+T-1,
            # which is always the final position of ``logits``.
            target_logit_index = P + T - 1
            final_logits = logits[:, target_logit_index, :]
            if torch.isnan(final_logits).any() or torch.isinf(final_logits).any():
                # Only reported here; the caller decides how to handle it.
                print(f"WARNING: NaN or Inf detected in final_logits at index {target_logit_index}.")
            return final_logits, None  # (B, V)

        # Training loss: soft-prompt positions get the ignore label, then the
        # usual next-token shift is applied.
        pad_ignore = torch.full((B, P), -1, dtype=targets.dtype, device=device)
        full_targets = torch.cat([pad_ignore, targets], dim=1)
        logits_lm = logits[:, :-1, :].contiguous()
        targets_lm = full_targets[:, 1:].contiguous()
        loss = F.cross_entropy(
            logits_lm.view(-1, logits_lm.size(-1)),
            targets_lm.view(-1),
            ignore_index=-1
        )
        if torch.isnan(loss) or torch.isinf(loss):
            print("WARNING: NaN or Inf detected in loss calculation.")
        return logits, loss

    @torch.no_grad()
    def generate_labels(self, idx, allowed_mask, max_new_tokens=24, temperature=0.0):
        """Generate up to ``max_new_tokens`` tokens restricted to an allowed set.

        Args:
            idx: (B, T) prompt token ids.
            allowed_mask: additive mask over the vocabulary, 0.0 for allowed
                ids and -inf for forbidden ones; shape (V,) or any shape that
                broadcasts against (B, V).
            max_new_tokens: upper bound on generated tokens per sequence.
            temperature: ``<= 0`` selects greedy decoding, otherwise tokens
                are sampled from the softmax of ``logits / temperature``.

        Returns:
            (B, n) tensor holding only the generated ids, ``n <= max_new_tokens``.
            A sequence that has produced the EOS id keeps emitting EOS;
            generation stops early once every sequence has finished, or when
            the model returns NaN/Inf logits.

        The EOS id is read from the module-level ``eos_id``.
        """
        self.eval()  # Ensure model is in eval mode
        B = idx.size(0)
        prompt_width = idx.size(1)
        block_size = self.config.block_size

        out = idx.clone()  # Clone to avoid modifying the caller's tensor
        allowed_mask = allowed_mask.to(idx.device)
        # A (V,) mask is the normal case: give it a batch axis once so it
        # broadcasts against the (B, V) logits on every step.
        step_mask = allowed_mask.unsqueeze(0) if allowed_mask.dim() == 1 else allowed_mask

        finished = torch.zeros(B, dtype=torch.bool, device=idx.device)

        current_eos_id = eos_id  # module-level value set by the tokenizer loader
        if not isinstance(current_eos_id, int):
            print(f"Warning: Global eos_id is not an integer ({current_eos_id}). Defaulting to 0.")
            current_eos_id = 0

        for step in range(max_new_tokens):
            # forward() does not crop, and the position-embedding table only
            # has block_size rows, so keep just the most recent block_size ids.
            ctx = out if out.size(1) <= block_size else out[:, -block_size:]

            # Logits for the token that follows the last one in ctx: (B, V)
            logits, _ = self(ctx)

            if torch.isnan(logits).any() or torch.isinf(logits).any():
                print(f"WARNING: NaN or Inf detected in logits during generation step {step}. Stopping generation.")
                # Return what was generated before the bad step.
                return out[:, prompt_width:]

            # Apply the constraint: forbidden ids drop to -inf.
            logits = logits + step_mask

            if temperature <= 0:
                # Greedy decoding
                next_id = torch.argmax(logits, dim=-1)  # (B,)
            else:
                # Temperature sampling, falling back to greedy on bad probabilities.
                probs = F.softmax(logits / temperature, dim=-1)
                if torch.isnan(probs).any() or torch.isinf(probs).any():
                    print(f"WARNING: NaN or Inf detected in probabilities during generation step {step}. Using argmax fallback.")
                    next_id = torch.argmax(logits, dim=-1)
                else:
                    try:
                        next_id = torch.multinomial(probs, num_samples=1).squeeze(1)  # (B,)
                    except RuntimeError as e:
                        print(f"WARNING: torch.multinomial failed: {e}. Using argmax fallback.")
                        next_id = torch.argmax(logits, dim=-1)

            # Sequences that already finished are forced to keep emitting EOS.
            next_id = next_id.masked_fill(finished, current_eos_id)

            if (next_id < 0).any():
                print(f"WARNING: Negative token ID generated: {next_id}. Clipping to 0.")
                next_id = torch.clamp(next_id, min=0)

            out = torch.cat([out, next_id.unsqueeze(1)], dim=1)
            finished |= (next_id == current_eos_id)

            # Stop if all sequences in the batch are finished
            if bool(finished.all()):
                break

        # Return only the generated part (excluding the initial idx length)
        return out[:, prompt_width:]
376
+
377
+
378
+ # --- Tokenizer Helper Functions ---
379
+ # Added robustness and error checks
380
+ # Global tokenizer maps and special IDs, loaded once at startup
381
# Tokenizer state shared by the tokenizer helper functions in this file;
# load_tokenizer_data() overwrites every one of these.
token2id = {}  # token string -> integer id
id2token = {}  # integer id -> token string
eos_id = 0  # Default, will be overwritten
pad_id = 0  # Default, will be overwritten
detokenizer = None
385
+
386
def load_tokenizer_data(dict_path):
    """Load the vocabulary file and set up the module-level tokenizer state.

    Each non-empty line contributes its first whitespace-separated field as a
    token whose id is the zero-based line index (skipped lines still consume an
    index; a repeated token keeps its first id). Afterwards ``eos_id``,
    ``pad_id`` and ``detokenizer`` are set.

    Raises:
        FileNotFoundError: ``dict_path`` does not exist.
        Exception: any other failure is logged and re-raised unchanged.
    """
    global token2id, id2token, eos_id, pad_id, detokenizer
    print(f"Loading vocabulary from {dict_path}...")
    forward_map, reverse_map = {}, {}
    try:
        with open(dict_path, encoding="utf-8") as vocab_file:
            for i, line in enumerate(vocab_file):
                fields = line.split()
                if not fields:
                    continue  # Skip empty lines
                tok = fields[0]
                if tok in forward_map:
                    print(f"Warning: Duplicate token '{tok}' found at line {i+1}. Keeping first occurrence.")
                    continue
                forward_map[tok] = i
                reverse_map[i] = tok

        # Publish the maps only once the whole file was read.
        token2id = forward_map
        id2token = reverse_map

        # First of the common EOS spellings present in the vocabulary, if any.
        eos_tok = next(
            (cand for cand in ("</s>", "<|endoftext|>", "[EOS]") if cand in token2id),
            None,
        )
        if eos_tok is not None:
            eos_id = token2id[eos_tok]
            print(f"Found EOS token '{eos_tok}' with ID: {eos_id}")
        else:
            # No known EOS token: fall back to the highest id, or 0 when empty.
            eos_id = max(token2id.values()) if token2id else 0
            print(f"Warning: Standard EOS tokens not found. Using highest index ({eos_id}) as EOS ID.")

        # Prefer a dedicated <pad> token; otherwise reuse the EOS id.
        pad_id = token2id.get("<pad>", eos_id)
        print(f"Using PAD ID: {pad_id}")

        detokenizer = MosesDetokenizer(lang='en')  # Initialize once
        print(f"Vocabulary loaded. Size: {len(token2id)}")
        if not detokenizer:
            raise ValueError("MosesDetokenizer failed to initialize.")

    except FileNotFoundError:
        print(f"ERROR: Vocabulary file not found at {dict_path}")
        raise
    except Exception as e:
        print(f"ERROR: Failed to load tokenizer data from {dict_path}: {e}")
        raise
436
+
437
def bpe_encode_lines(lines, shard_size=500, desc="BPE Encode"):
    """Encode text lines with the external fastBPE binary.

    Args:
        lines: list (or other iterable) of texts, one logical line each.
            Falsy items are encoded as empty lines; other items go through
            ``str()``.
        shard_size: number of lines handed to one fastBPE invocation.
        desc: accepted for call-site compatibility; not used.

    Returns:
        One list of BPE token strings per line read back from fastBPE's
        output, in input order.

    Raises:
        ValueError: ``lines`` cannot be converted to a list.
        FileNotFoundError: the fastBPE binary or the BPE codes file is missing.
        PermissionError: the binary is not executable and chmod failed.
        subprocess.CalledProcessError: fastBPE exited with a non-zero status.
    """
    # --- Input Validation ---
    if not isinstance(lines, list):
        print(f"Warning: bpe_encode_lines expected a list, got {type(lines)}. Attempting conversion.")
        try:
            lines = list(lines)
        except TypeError:
            raise ValueError("Input 'lines' must be a list or convertible to a list.")

    if not lines:
        return []

    # --- Path and Executable Checks ---
    abs_fastbpe_path = os.path.abspath(FASTBPE_BIN_PATH)
    abs_bpe_codes_path = os.path.abspath(BPE_CODES_PATH)

    if not os.path.exists(abs_fastbpe_path):
        raise FileNotFoundError(f"fastBPE executable not found at {abs_fastbpe_path}")
    if not os.path.exists(abs_bpe_codes_path):
        raise FileNotFoundError(f"BPE codes file not found at {abs_bpe_codes_path}")
    if not os.access(abs_fastbpe_path, os.X_OK):
        print(f"Warning: fastBPE binary at {abs_fastbpe_path} is not executable. Attempting chmod...")
        try:
            os.chmod(abs_fastbpe_path, 0o755)
        except OSError as e:
            raise PermissionError(f"Failed to make fastBPE executable: {e}. Please check permissions.")

    out_tokens = []
    # Process in chunks; the temporary directory is removed on exit.
    with tempfile.TemporaryDirectory() as td:
        for start in range(0, len(lines), shard_size):
            chunk = lines[start:start+shard_size]
            src_path = os.path.join(td, f"src_{start}.txt")
            dst_path = os.path.join(td, f"dst_{start}.bpe")

            try:
                with open(src_path, "w", encoding="utf-8") as f:
                    for s in chunk:
                        # The file is line-oriented, so a line break inside one
                        # item would spill into extra lines and shift every
                        # later result; flatten embedded breaks to spaces.
                        text = str(s or "").replace("\r", " ").replace("\n", " ")
                        f.write(text.strip() + "\n")

                # Run fastBPE
                cmd = [abs_fastbpe_path, "applybpe", dst_path, src_path, abs_bpe_codes_path]
                process = subprocess.run(
                    cmd,
                    capture_output=True, text=True, check=False  # failure is handled below
                )

                if process.returncode != 0:
                    # Log the details before raising.
                    print(f"ERROR: fastBPE failed (exit code {process.returncode}) on chunk starting at index {start}.")
                    print(f"Command: {' '.join(cmd)}")
                    print(f"Stderr:\n{process.stderr}")
                    print(f"First line of input chunk: {chunk[0] if chunk else 'N/A'}")
                    raise subprocess.CalledProcessError(process.returncode, cmd, output=process.stdout, stderr=process.stderr)

                # Read results if successful
                with open(dst_path, "r", encoding="utf-8") as f:
                    for line in f:
                        out_tokens.append(line.strip().split())

            except subprocess.CalledProcessError:
                # Details were already printed above.
                raise
            except Exception as e:
                print(f"ERROR: Unexpected error during BPE encoding chunk starting at index {start}: {e}")
                traceback.print_exc()  # Print full traceback for unexpected errors
                raise
    return out_tokens
510
+
511
+
512
def tokens_to_ids(bpe_tokens):
    """Map BPE token strings to ids through the module-level ``token2id``.

    Tokens missing from the vocabulary, and items that are not strings, are
    mapped to ``pad_id`` and counted as out-of-vocabulary.

    Returns:
        ``(ids, oov_count)``.

    Raises:
        ValueError: ``bpe_tokens`` is not a list.
    """
    if not isinstance(bpe_tokens, list):
        raise ValueError(f"Input 'bpe_tokens' must be a list, got {type(bpe_tokens)}.")

    ids = []
    oov_count = 0
    for t in bpe_tokens:
        if isinstance(t, str):
            if t in token2id:
                ids.append(token2id[t])
            else:
                ids.append(pad_id)
                oov_count += 1
        else:
            print(f"Warning: Non-string token found in bpe_tokens: {t}. Using pad_id.")
            ids.append(pad_id)
            oov_count += 1

    if oov_count > 0:
        print(f"Info: Found {oov_count} OOV tokens in sequence of length {len(bpe_tokens)}.")
    return ids, oov_count
535
+
536
def ids_to_tokens(ids):
    """Map ids back to token strings through the module-level ``id2token``.

    A float NaN becomes ``"<nan>"``; an id that is unknown or cannot be
    converted with ``int()`` becomes ``"<unk>"``.

    Raises:
        ValueError: ``ids`` is not a list.
    """
    if not isinstance(ids, list):
        raise ValueError(f"Input 'ids' must be a list, got {type(ids)}.")

    def lookup(i):
        # Resolve one id, tolerating values that are not clean integers.
        try:
            if isinstance(i, float) and math.isnan(i):
                return "<nan>"
            key = int(i)
            return id2token.get(key, "<unk>")
        except (ValueError, TypeError):
            print(f"Warning: Could not convert ID '{i}' to int. Using '<unk>'.")
            return "<unk>"

    return [lookup(i) for i in ids]
557
+
558
+
559
def bpe_decode_tokens(bpe_tokens):
    """Turn a list of BPE tokens back into readable text.

    The tokens are joined with spaces, every ``'@@ '`` continuation marker is
    removed, and the result is passed to the module-level detokenizer.

    Returns:
        The detokenized text; ``""`` when nothing but whitespace is left;
        ``"<decoding error>"`` or ``"<detokenization error>"`` when the
        respective step raises.

    Raises:
        RuntimeError: the detokenizer has not been initialized.
        ValueError: ``bpe_tokens`` is not a list.
    """
    if detokenizer is None:
        raise RuntimeError("Detokenizer not initialized. Call load_tokenizer_data first.")
    if not isinstance(bpe_tokens, list):
        raise ValueError(f"Input 'bpe_tokens' must be a list, got {type(bpe_tokens)}.")

    # Every item must be a string before joining.
    try:
        pieces = list(map(str, bpe_tokens))
    except Exception as e:
        print(f"Error converting tokens to strings: {e}. Tokens: {bpe_tokens}")
        return "<decoding error>"

    s = ' '.join(pieces).replace('@@ ', '')
    if not s.strip():
        return ""

    # The detokenizer might fail on unusual input.
    try:
        return detokenizer.detokenize(s.split())
    except Exception as e:
        print(f"Error during detokenization: {e}. Input string: '{s}'")
        return "<detokenization error>"
581
+
582
+
583
+ # --- Prediction Helper Functions ---
584
+
585
def to_canonical(pred_chunk: str):
    """Resolve a predicted text chunk to its canonical hallmark name.

    The chunk is stripped and lower-cased. An exact match against
    ``HALLMARKS`` wins; otherwise the closest ``difflib`` match with a
    similarity of at least 0.7 is used.

    Returns:
        The matching entry of ``HALLMARKS``, or ``None`` when the input is not
        a string, is blank, has no close match, or matching raised.
    """
    if not isinstance(pred_chunk, str):
        return None

    s = pred_chunk.strip().lower()
    if not s:
        return None

    low = [name.lower() for name in HALLMARKS]

    # Exact, case-insensitive match.
    for canonical, lowered in zip(HALLMARKS, low):
        if lowered == s:
            return canonical

    # Fuzzy match.
    try:
        best = difflib.get_close_matches(s, low, n=1, cutoff=0.7)
        if not best:
            return None
        return HALLMARKS[low.index(best[0])]
    except Exception as e:
        print(f"Error during difflib matching for '{s}': {e}")
        return None  # Return None on error
607
+
608
def build_allowed_token_mask(vocab_size, device):
    """Build the additive vocabulary mask used for constrained decoding.

    The allowed ids are those of the BPE-encoded hallmark names, of a fixed set
    of separators, and of the EOS token. The pad id is dropped from the set
    unless it is also the EOS id.

    Args:
        vocab_size: length of the mask; must be positive.
        device: target device for the returned tensor.

    Returns:
        Tensor of shape (vocab_size,) holding 0.0 at allowed ids and -inf
        everywhere else, on ``device``.

    Raises:
        ValueError: ``vocab_size`` is not positive, ``HALLMARKS`` is malformed,
            or no valid id is left.
        RuntimeError: the tokenizer vocabulary has not been loaded.
    """
    allowed = set()

    # --- Input Validation ---
    if vocab_size <= 0:
        raise ValueError("Vocabulary size must be positive.")
    if not token2id:
        raise RuntimeError("Tokenizer vocabulary (token2id) not loaded.")

    def collect_ids(texts, desc):
        # BPE-encode the texts and add every resulting id to the allowed set.
        for bpe_list in bpe_encode_lines(texts, desc=desc):
            ids, _ = tokens_to_ids(bpe_list)
            allowed.update(ids)

    print("Encoding hallmarks for mask...")
    try:
        if not (isinstance(HALLMARKS, list) and all(isinstance(h, str) for h in HALLMARKS)):
            raise ValueError("HALLMARKS must be a list of strings.")
        collect_ids(HALLMARKS, "BPE Hallmarks (for mask)")
        print(f"Encoded {len(HALLMARKS)} hallmarks.")
    except Exception as e:
        print(f"ERROR: Failed to BPE encode or convert hallmarks for mask: {e}")
        raise

    print("Encoding separators for mask...")
    SEPS = [", ", ",", "; ", ";", "|"]
    try:
        collect_ids(SEPS, "BPE Separators (for mask)")
        print(f"Encoded {len(SEPS)} separators.")
    except Exception as e:
        print(f"ERROR: Failed to BPE encode or convert separators for mask: {e}")
        raise

    # Add the EOS token, substituting 0 when the configured id is unusable.
    if isinstance(eos_id, int) and 0 <= eos_id < vocab_size:
        effective_eos_id = eos_id
    else:
        print(f"Warning: Invalid EOS ID ({eos_id}). Defaulting mask EOS to 0.")
        effective_eos_id = 0
    allowed.add(effective_eos_id)
    print(f"Total allowed token IDs (including EOS {effective_eos_id}): {len(allowed)}")

    # Build on CPU first; everything starts out forbidden.
    mask = torch.full((vocab_size,), float('-inf'), device=torch.device('cpu'))
    try:
        # -1 can never equal a kept id, so an unusable pad_id filters nothing.
        pad_is_valid = isinstance(pad_id, int) and 0 <= pad_id < vocab_size
        effective_pad_id = pad_id if pad_is_valid else -1

        # Keep in-range integer ids; drop the pad id unless it doubles as EOS.
        valid_allowed_ids = [
            id_ for id_ in allowed
            if isinstance(id_, int) and 0 <= id_ < vocab_size
            and (id_ != effective_pad_id or id_ == effective_eos_id)
        ]

        if not valid_allowed_ids:
            raise ValueError("No valid token IDs found to allow in the mask.")

        # Check ranges again after filtering (belt and braces).
        max_valid_id = max(valid_allowed_ids)
        min_valid_id = min(valid_allowed_ids)
        if max_valid_id >= vocab_size or min_valid_id < 0:
            raise IndexError(f"Filtered allowed IDs still out of range [{min_valid_id}, {max_valid_id}] for vocab size {vocab_size}.")

        mask[valid_allowed_ids] = 0.0
        print(f"Mask created with {len(valid_allowed_ids)} allowed indices.")

    except IndexError as e:
        print(f"ERROR: Index error while creating mask. Vocab size: {vocab_size}. Error: {e}")
        problem_ids = [i for i in allowed if not isinstance(i, int) or i < 0 or i >= vocab_size]
        print(f"Problematic IDs in allowed set: {problem_ids}")
        raise
    except Exception as e:
        print(f"ERROR: Unexpected error creating mask: {e}")
        traceback.print_exc()
        raise

    # Move the finished mask to the target device.
    try:
        target_device = torch.device(device)
        return mask.to(target_device)
    except Exception as e:
        print(f"Error moving mask to device '{device}': {e}")
        raise
701
+
702
+
703
# --- Global Variables for Loaded Model and Assets ---
# Module-level inference state. Everything starts out unset and is filled in
# by initialize_model_and_tokenizer().
#   inference_model - the GPTWithSoftPrompt wrapper once it has been built
#   ALLOWED_MASK    - the tensor returned by build_allowed_token_mask()
#   config          - the GPTConfig the model was built from
inference_model = ALLOWED_MASK = config = None
# Device name used for the model and inputs; switched to "cuda" when available.
model_device = "cpu"
708
+
709
+ # --- Initialization Function ---
710
def initialize_model_and_tokenizer():
    """Load the tokenizer, build the model, load its weights and build the token mask.

    Runs the start-up stages in order: pick the device, load the vocabulary,
    create the ``GPTConfig``, instantiate ``GPT`` wrapped in
    ``GPTWithSoftPrompt``, load the finetuned checkpoint, move the model to
    the device in eval mode, and build ``ALLOWED_MASK``. Each stage logs its
    own failure and aborts the remaining stages.

    Side effects:
        Rebinds the module-level ``model_device``, ``config``,
        ``inference_model`` and ``ALLOWED_MASK``.

    Returns:
        bool: True when every stage succeeded, False as soon as one failed.
    """
    global inference_model, ALLOWED_MASK, model_device, token2id, config

    print("Initializing model...")
    model_device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {model_device}")

    # The vocabulary comes first: its size is the model's vocab_size.
    try:
        load_tokenizer_data(DICT_TXT_PATH)
        if not token2id:  # loader ran but left the vocabulary empty
            raise ValueError("Tokenizer loading failed to populate token2id dictionary.")
    except Exception as e:
        print(f"FATAL: Could not load tokenizer data. Cannot proceed. Error: {e}")
        return False

    # Architecture hyper-parameters (MUST match the finetuning config).
    try:
        config = GPTConfig(
            vocab_size=len(token2id),
            block_size=128,
            n_layer=6,
            n_head=6,
            n_embd=384,
            dropout=0.1,  # dropout is off in eval mode
            bias=True,
        )
        print(f"Model Config: {config}")
    except Exception as e:
        print(f"FATAL: Error creating GPTConfig: {e}")
        return False

    # Build the base model and its soft-prompt wrapper (on CPU initially).
    try:
        base_gpt = GPT(config)
        inference_model = GPTWithSoftPrompt(base_gpt, prompt_len=1)
    except Exception as e:
        print(f"FATAL: Error instantiating model: {e}")
        traceback.print_exc()
        return False

    print(f"Loading finetuned weights from: {FINETUNED_MODEL_PATH}")
    if not os.path.exists(FINETUNED_MODEL_PATH):
        print(f"ERROR: Model weights file not found at {FINETUNED_MODEL_PATH}")
        return False

    try:
        # Load onto CPU first; the model is moved to the target device below.
        state_dict = torch.load(FINETUNED_MODEL_PATH, map_location='cpu')

        # Checkpoints saved from a DDP-wrapped model prefix every key with 'module.'.
        ddp_prefix = 'module.'
        cleaned_state_dict = {
            (k[len(ddp_prefix):] if k.startswith(ddp_prefix) else k): v
            for k, v in state_dict.items()
        }

        missing_keys, unexpected_keys = inference_model.load_state_dict(cleaned_state_dict, strict=False)
        # load_state_dict() only reports keys that are part of state_dict(),
        # i.e. parameters and persistent buffers, so no extra filtering is
        # needed. The previous filter called get_parameter() on every missing
        # key; that raises AttributeError for buffer names and turned a mere
        # warning into a failed initialization.
        if missing_keys:
            print("Warning: Missing persistent keys during state dict load:", missing_keys)
        if unexpected_keys:
            print("Warning: Unexpected keys during state dict load:", unexpected_keys)
        print("Weights loaded successfully.")
    except Exception as e:
        print(f"Error loading state dict from {FINETUNED_MODEL_PATH}: {e}")
        print("Ensure the model architecture matches the saved checkpoint and the file is not corrupted.")
        traceback.print_exc()
        return False

    try:
        inference_model.to(model_device)
        inference_model.eval()
        print(f"Model moved to device: {model_device} and set to eval mode.")
    except Exception as e:
        print(f"Error moving model to device '{model_device}': {e}")
        traceback.print_exc()
        return False

    # Build the allowed token mask (after the model is on its device).
    print("Building allowed token mask...")
    try:
        if config.vocab_size <= 0:
            raise ValueError("Vocabulary size must be positive to build mask.")
        device_obj = torch.device(model_device)
        ALLOWED_MASK = build_allowed_token_mask(config.vocab_size, device_obj)
        print("Allowed token mask created.")
    except Exception as e:
        print(f"ERROR: Failed to build allowed token mask: {e}")
        traceback.print_exc()
        return False

    return True
812
+
813
+
814
+ # --- Inference Function ---
815
def predict_hallmarks(abstract_text):
    """Predict Hallmarks-of-Cancer labels for a single abstract.

    The abstract is whitespace-normalised, BPE-encoded, mapped to vocabulary
    ids, terminated with ``eos_id`` and passed to
    ``inference_model.generate_labels`` together with ``ALLOWED_MASK``. The
    generated ids are decoded to lower-case text, split on ``;``, ``,`` and
    ``|``, and each piece is mapped through ``to_canonical``.

    Args:
        abstract_text: The abstract. Non-string values are coerced with ``str()``.

    Returns:
        list[str]: Canonical labels in generation order without duplicates.
        Empty for blank input or when nothing usable is generated. Failures
        are reported as a one-element list whose string starts with
        ``"Error:"`` rather than by raising.
    """
    # --- Readiness checks: (failed?, log line, value handed back) ---
    readiness = (
        (inference_model is None, "Error: Inference model is not loaded.", "Error: Model not loaded"),
        (ALLOWED_MASK is None, "Error: Allowed mask is not built.", "Error: Mask not built"),
        (not token2id, "Error: Tokenizer vocabulary not loaded.", "Error: Tokenizer not loaded"),
    )
    for failed, log_line, outcome in readiness:
        if failed:
            print(log_line)
            return [outcome]

    # --- Input validation ---
    if not isinstance(abstract_text, str):
        print(f"Warning: Received non-string abstract text type: {type(abstract_text)}. Attempting conversion.")
        try:
            abstract_text = str(abstract_text)
        except Exception:
            return ["Error: Invalid input type"]
    if abstract_text.strip() == "":
        print("Warning: Received empty or whitespace-only abstract text.")
        return []

    try:
        # --- 1. Normalise whitespace and tokenize ---
        print("Tokenizing input abstract...")
        normalized = " ".join(abstract_text.split())
        if not normalized:
            print("Warning: Input abstract contains only whitespace after cleaning.")
            return []

        encoded_lines = bpe_encode_lines([normalized])
        if not encoded_lines or not encoded_lines[0]:
            print("Warning: BPE encoding resulted in empty tokens.")
            return []

        input_ids_list, oov = tokens_to_ids(encoded_lines[0])
        if oov > 0:
            print(f"Info: Input contained {oov} OOV tokens.")

        # The OOM handler below looks this name up via locals(); keep it as is.
        input_ids = input_ids_list + [eos_id]

        # Shape (1, sequence length) on the model's device.
        input_tensor = torch.tensor(input_ids, dtype=torch.long)
        input_tensor = input_tensor.unsqueeze(0).to(model_device)

        # --- 2. Generate ---
        print("Generating predictions...")
        with torch.no_grad():
            generated_ids_tensor = inference_model.generate_labels(
                input_tensor,
                allowed_mask=ALLOWED_MASK,
                max_new_tokens=30,
                temperature=0.0,
            )

        # --- 3. Decode and post-process ---
        print("Decoding and cleaning predictions...")
        if generated_ids_tensor is None or generated_ids_tensor.numel() == 0:
            print("Warning: Generation resulted in empty tensor.")
            generated_ids = []
        else:
            # Bring the first sequence back to the CPU before listing it.
            generated_ids = generated_ids_tensor[0].cpu().tolist()

        if not generated_ids:
            print("No tokens generated.")
            return []

        generated_tokens = ids_to_tokens(generated_ids)

        # Drop the EOS token and everything after it, if EOS was produced.
        eos_marker = id2token.get(eos_id, "</s>")
        try:
            cut_at = generated_tokens.index(eos_marker)
        except ValueError:
            pass  # EOS not found is okay
        else:
            generated_tokens = generated_tokens[:cut_at]

        generated_text = bpe_decode_tokens(generated_tokens).strip().lower()
        print(f"Raw generated text: '{generated_text}'")

        # Split potential multi-labels on the separator characters.
        parts = []
        if generated_text:
            parts = [
                piece.strip()
                for piece in re.split(r'[;,|]\s*', generated_text)
                if piece.strip()
            ]
            if not parts:
                parts = [generated_text]

        # Map to canonical names, keeping first-seen order and no duplicates.
        predicted_labels = []
        already_emitted = set()
        for piece in parts:
            label = to_canonical(piece)
            if not label or label in already_emitted:
                continue
            already_emitted.add(label)
            predicted_labels.append(label)

        print(f"Final predicted labels: {predicted_labels}")
        return predicted_labels

    # --- Error handling: log, then report through the return value ---
    except FileNotFoundError as fnf_err:
        print(f"ERROR during prediction (File Not Found - likely BPE related): {fnf_err}")
        traceback.print_exc()
        return ["Error: BPE file processing error"]
    except PermissionError as perm_err:
        print(f"ERROR during prediction (Permission Error - likely fastBPE): {perm_err}")
        traceback.print_exc()
        return ["Error: BPE execution permission"]
    except RuntimeError as run_err:
        if "CUDA out of memory" not in str(run_err):
            print(f"ERROR during prediction (PyTorch RuntimeError): {run_err}")
            traceback.print_exc()
            return ["Error: Model runtime error"]
        print(f"ERROR: CUDA Out of Memory during prediction. Input length: {len(input_ids) if 'input_ids' in locals() else 'N/A'}")
        traceback.print_exc()
        return ["Error: Input too long (OOM)"]
    except Exception as e:
        print(f"ERROR during prediction (General Exception): {e}")
        traceback.print_exc()
        return ["Error: An unexpected error occurred"]
944
+
945
+
946
# --- Flask App ---
# WSGI application object; the route and hook decorators below register on it.
app = Flask(import_name=__name__)

# --- Load Model on Startup ---
# Becomes True once initialize_model_and_tokenizer() has reported success.
model_initialized = False
951
+
952
import threading

# Serialises lazy model initialization across the server's request threads.
_model_init_lock = threading.Lock()


@app.before_request
def ensure_model_loaded():
    """Initialize the model lazily, exactly once, before a request is handled.

    Registered as a ``before_request`` hook, so it runs ahead of every
    request. While ``model_initialized`` is False it attempts
    ``initialize_model_and_tokenizer()``. A failed attempt is only logged,
    so the request still proceeds and the next request retries.

    The attempt is guarded by a lock and the flag is re-checked inside it.
    Without that, requests arriving concurrently before the first
    initialization finished would each run the initializer, and a later run
    would rebind the global ``inference_model`` to a freshly built model
    while another thread is already serving predictions.

    Returns:
        None, always, so Flask continues with normal request handling.
    """
    global model_initialized
    if model_initialized:
        return  # fast path: nothing to do and no lock contention

    with _model_init_lock:
        if model_initialized:
            return  # another thread finished while this one was waiting

        print("First request received, attempting to initialize model...")
        try:
            if initialize_model_and_tokenizer():
                model_initialized = True
                print("Model initialization successful.")
            else:
                # Not raised on purpose: /predict answers 503 until a retry succeeds.
                print("FATAL: Model initialization failed during first request.")
        except Exception as init_err:
            print(f"FATAL: Exception during model initialization: {init_err}")
            traceback.print_exc()
970
+
971
+
972
+ # --- Routes ---
973
@app.route('/')
def home():
    """Serve the frontend page rendered from the ``index.html`` template.

    Initialization failures are not surfaced here; they show up as errors
    from the ``/predict`` endpoint instead.
    """
    page = render_template('index.html')
    return page
979
+
980
@app.route('/predict', methods=['POST'])
def predict():
    """Handle a prediction request from the frontend.

    Expects a JSON object of the form ``{"abstract": "<text>"}``.

    Responses (all JSON):
        200 ``{"predictions": [...]}`` on success; the list is empty for a
            whitespace-only abstract.
        400 ``{"error": ...}`` when the body is not a JSON object, or
            ``abstract`` is missing, empty or not a string.
        413 ``{"error": ...}`` when the abstract exceeds ``MAX_ABSTRACT_LEN``.
        500 ``{"error": ...}`` when prediction fails.
        503 ``{"error": ...}`` while the model is not initialized.
    """
    if not model_initialized:
        print("Error: Model not initialized when /predict called.")
        return jsonify({'error': 'Model is not ready. Please try again later.'}), 503

    if not request.is_json:
        return jsonify({'error': 'Request must be JSON'}), 400

    # silent=True yields None for an unparsable body instead of Flask's HTML
    # 400 page, so the client always gets a JSON error it can display.
    data = request.get_json(silent=True)
    # Valid JSON is not necessarily an object: null, lists and scalars have no
    # .get() and previously crashed the handler with an AttributeError.
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    abstract = data.get('abstract')

    # Type check first, so falsy non-strings (0, [], false) get the accurate message.
    if abstract is not None and not isinstance(abstract, str):
        return jsonify({'error': '"abstract" field must be a string'}), 400
    if not abstract:
        return jsonify({'error': 'Missing "abstract" field in JSON request'}), 400
    if not abstract.strip():
        print("Received empty abstract, returning empty prediction.")
        return jsonify({'predictions': []})

    MAX_ABSTRACT_LEN = 10000  # upper bound on accepted input, in characters
    if len(abstract) > MAX_ABSTRACT_LEN:
        print(f"Received overly long abstract ({len(abstract)} chars), rejecting.")
        return jsonify({'error': f'Input abstract is too long (max {MAX_ABSTRACT_LEN} chars)'}), 413

    print("\n--- Received Prediction Request ---")
    print(f"Input Abstract (first 100 chars): {abstract[:100]}...")

    try:
        predictions = predict_hallmarks(abstract)
        print("--- Prediction Complete ---")

        # predict_hallmarks() reports failures in-band as ["Error: ..."].
        if isinstance(predictions, list) and len(predictions) > 0 and predictions[0].startswith("Error:"):
            print(f"Internal error during prediction: {predictions[0]}")
            # Keep internal details out of the client-facing message.
            return jsonify({'error': 'An internal error occurred during prediction.'}), 500

        return jsonify({'predictions': predictions})

    except Exception as e:
        # Anything unexpected in the route handler itself.
        print("--- Prediction Failed Unexpectedly in Route ---")
        print(f"Error: {e}")
        traceback.print_exc()
        return jsonify({'error': 'An internal server error occurred.'}), 500
1033
+
1034
+ # --- Run the App ---
1035
if __name__ == '__main__':
    # Direct-run entry point: verify the fastBPE binary, load the model
    # eagerly, then start Flask's built-in server.

    # The fastBPE check has to come BEFORE model initialization. Initialization
    # builds the token mask via bpe_encode_lines(), which presumably runs this
    # binary (the prediction error handlers attribute PermissionError to
    # fastBPE; verify against bpe_encode_lines). Checking afterwards meant a
    # missing or non-executable binary surfaced as an opaque initialization
    # failure and the chmod repair below was never reached.
    abs_fastbpe_path = os.path.abspath(FASTBPE_BIN_PATH)
    if not os.path.exists(abs_fastbpe_path):
        print(f"ERROR: fastBPE binary not found at '{abs_fastbpe_path}'.")
        print("Please ensure fastBPE is compiled and the path is correct relative to app.py.")
        # SystemExit instead of exit(): exit() is injected by the site module
        # and is not guaranteed to exist.
        raise SystemExit(1)
    if not os.access(abs_fastbpe_path, os.X_OK):
        print(f"ERROR: fastBPE binary at '{abs_fastbpe_path}' is not executable.")
        print("Attempting to make it executable with 'chmod +x'...")
        try:
            os.chmod(abs_fastbpe_path, 0o755)
            print(f"Successfully made '{abs_fastbpe_path}' executable.")
        except OSError as e:
            print(f"ERROR: Failed to make fastBPE executable: {e}")
            print("Please set execute permissions manually (e.g., 'chmod +x ./fast').")
            raise SystemExit(1)

    # Initialize eagerly so a broken setup stops the process before it serves.
    if not model_initialized:
        print("Running script directly, initializing model eagerly...")
        if initialize_model_and_tokenizer():
            model_initialized = True
            print("Model initialization successful.")
        else:
            print("FATAL: Model initialization failed. Cannot start Flask server.")
            raise SystemExit(1)

    print("Starting Flask server...")
    # host='0.0.0.0' makes the server reachable from other machines on the
    # network; debug stays False for production environments.
    app.run(host='0.0.0.0', port=5000, debug=False)
1067
+
bpecodes ADDED
The diff for this file is too large to render. See raw diff
 
dict.txt ADDED
The diff for this file is too large to render. See raw diff
 
fast ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c86a282547b0bc4d088adee84b45ad031db1b23139d944a159ae8f9cfce3186e
3
+ size 132336
hoc_best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:13d02e9a35bd2fc935f5f1291a16b568928e223cba9152f5284ef45755f91491
3
+ size 107913359
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ Flask
2
+ torch
3
+ sacremoses==0.0.53
4
+ numpy
5
+ pandas
6
+ tqdm
7
+ waitress
templates/index.html ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>BioGPT HoC Inference</title>
7
+ <!-- Load Tailwind CSS -->
8
+ <script src="https://cdn.tailwindcss.com"></script>
9
+ <!-- Include Inter font -->
10
+ <link rel="preconnect" href="https://fonts.googleapis.com">
11
+ <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
12
+ <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap" rel="stylesheet">
13
+ <style>
14
+ body {
15
+ font-family: 'Inter', sans-serif;
16
+ }
17
+ .hallmark-badge {
18
+ display: inline-block;
19
+ padding: 0.3rem 0.6rem;
20
+ margin: 0.2rem;
21
+ border-radius: 0.5rem; /* Rounded corners */
22
+ font-size: 0.8rem;
23
+ line-height: 1rem;
24
+ font-weight: 500;
25
+ cursor: default;
26
+ transition: background-color 0.2s ease-in-out;
27
+ border: 1px solid transparent;
28
+ }
29
+ /* Badge color palette: one class per hallmark index (0-9) */
30
+ .hallmark-color-0 { background-color: #EFF6FF; color: #1E40AF; border-color: #BEE3F8; } /* Blue */
31
+ .hallmark-color-1 { background-color: #ECFDF5; color: #065F46; border-color: #A7F3D0; } /* Green */
32
+ .hallmark-color-2 { background-color: #FFFBEB; color: #92400E; border-color: #FDE68A; } /* Yellow */
33
+ .hallmark-color-3 { background-color: #FEF2F2; color: #991B1B; border-color: #FECACA; } /* Red */
34
+ .hallmark-color-4 { background-color: #F5F3FF; color: #5B21B6; border-color: #DDD6FE; } /* Purple */
35
+ .hallmark-color-5 { background-color: #FDF2F8; color: #9D174D; border-color: #FBCFE8; } /* Pink */
36
+ .hallmark-color-6 { background-color: #EEF2FF; color: #3730A3; border-color: #C7D2FE; } /* Indigo */
37
+ .hallmark-color-7 { background-color: #F0FDFA; color: #134E4A; border-color: #99F6E4; } /* Teal */
38
+ .hallmark-color-8 { background-color: #FFF7ED; color: #9A3412; border-color: #FED7AA; } /* Orange */
39
+ .hallmark-color-9 { background-color: #F9FAFB; color: #374151; border-color: #E5E7EB; } /* Gray */
40
+
41
+ /* Loading Spinner */
42
+ .spinner {
43
+ border: 4px solid rgba(0, 0, 0, 0.1);
44
+ width: 36px;
45
+ height: 36px;
46
+ border-radius: 50%;
47
+ border-left-color: #4f46e5; /* Indigo */
48
+ animation: spin 1s ease infinite;
49
+ margin: 20px auto;
50
+ }
51
+ @keyframes spin {
52
+ 0% { transform: rotate(0deg); }
53
+ 100% { transform: rotate(360deg); }
54
+ }
55
+ </style>
56
+ </head>
57
+ <body class="bg-gray-100 min-h-screen flex flex-col items-center justify-center p-4 sm:p-6">
58
+
59
+ <div class="bg-white p-6 sm:p-8 rounded-lg shadow-xl w-full max-w-2xl space-y-6">
60
+
61
+ <!-- Header -->
62
+ <div class="text-center border-b pb-4">
63
+ <h1 class="text-2xl sm:text-3xl font-bold text-gray-800 mb-1">BioGPT HoC Inference</h1>
64
+ <h2 class="text-md sm:text-lg font-semibold text-blue-700 mb-2">~27M Parameter Finetuned SLM</h2>
65
+ <p class="text-sm text-gray-600 max-w-2xl mx-auto">
66
+ Enter a biomedical abstract below to predict Hallmarks of Cancer labels using the finetuned model.
67
+ </p>
68
+ </div>
69
+
70
+ <!-- Input Area -->
71
+ <div>
72
+ <label for="abstract-input" class="block text-sm font-medium text-gray-700 mb-1">Enter Abstract Text:</label>
73
+ <textarea id="abstract-input" rows="8" class="w-full p-3 border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500 focus:border-transparent transition duration-150 ease-in-out shadow-sm" placeholder="Paste or type abstract here..."></textarea>
74
+ </div>
75
+
76
+ <!-- Predict Button -->
77
+ <div class="text-center">
78
+ <button id="predict-button" class="bg-blue-600 hover:bg-blue-700 text-white font-semibold py-2.5 px-8 rounded-lg transition duration-150 ease-in-out shadow-md hover:shadow-lg focus:outline-none focus:ring-2 focus:ring-blue-500 focus:ring-offset-2 disabled:opacity-50 disabled:cursor-not-allowed">
79
+ Predict Hallmarks
80
+ </button>
81
+ </div>
82
+
83
+ <!-- Results Area -->
84
+ <div id="results-area" class="space-y-3 hidden">
85
+ <div class="border border-green-200 p-4 rounded-lg bg-green-50 shadow-sm min-h-[80px]">
86
+ <h3 class="text-md font-semibold text-green-800 mb-2">Predicted Hallmarks:</h3>
87
+ <div id="loading-indicator" class="hidden text-center">
88
+ <div class="spinner"></div>
89
+ <p class="text-sm text-gray-600">Predicting...</p>
90
+ </div>
91
+ <div id="error-message" class="hidden text-center text-red-600 font-medium">
92
+ An error occurred. Please try again.
93
+ </div>
94
+ <div id="prediction-output">
95
+ <!-- Badges will be inserted here -->
96
+ </div>
97
+ </div>
98
+ </div>
99
+
100
+ </div>
101
+
102
+ <script>
103
// --- Hallmark Colors (for consistent styling) ---
// Canonical hallmark names; ensure this matches the list in app.py.
const hallmarkList = [
    "activating invasion and metastasis",
    "avoiding immune destruction",
    "cellular energetics",
    "enabling replicative immortality",
    "evading growth suppressors",
    "genomic instability and mutation",
    "inducing angiogenesis",
    "resisting cell death",
    "sustaining proliferative signaling",
    "tumor promoting inflammation",
];

// Maps each hallmark name to its CSS class, hallmark-color-0 .. hallmark-color-9.
const hallmarkColors = hallmarkList.reduce((classByName, name, position) => {
    classByName[name] = `hallmark-color-${position % 10}`;
    return classByName;
}, {});

// --- DOM Elements ---
const byId = (elementId) => document.getElementById(elementId);
const abstractInput = byId('abstract-input');
const predictButton = byId('predict-button');
const resultsArea = byId('results-area');
const loadingIndicator = byId('loading-indicator');
const errorMessage = byId('error-message');
const predictionOutput = byId('prediction-output');

// --- Event Listener ---
predictButton.addEventListener('click', handlePrediction);
126
+
127
+ // --- Prediction Logic ---
128
/**
 * Send the abstract in the textarea to POST /predict and render the outcome.
 *
 * Shows the spinner and disables the button while the request is in flight.
 * On success the returned labels go to displayPredictions(); on any failure
 * the message is shown in the error area. The spinner is hidden and the
 * button re-enabled in every case.
 */
async function handlePrediction() {
    const abstractText = abstractInput.value.trim();
    if (!abstractText) {
        alert("Please enter some abstract text.");
        return;
    }

    // --- UI Updates: Start loading ---
    predictButton.disabled = true;
    resultsArea.classList.remove('hidden');
    predictionOutput.innerHTML = ''; // Clear previous results
    errorMessage.classList.add('hidden');
    loadingIndicator.classList.remove('hidden');

    try {
        // --- Call the backend API ---
        const response = await fetch('/predict', {
            method: 'POST',
            headers: {
                'Content-Type': 'application/json',
            },
            body: JSON.stringify({ abstract: abstractText }),
        });

        // --- Handle Response ---
        if (!response.ok) {
            // An error body is not guaranteed to be JSON (for example an HTML
            // error page). Parsing it unguarded replaced the real problem with
            // a JSON syntax error, so fall back to the HTTP status instead.
            let serverMessage = null;
            try {
                const errorData = await response.json();
                serverMessage = errorData && errorData.error;
            } catch (parseError) {
                serverMessage = null;
            }
            throw new Error(serverMessage || `HTTP error! status: ${response.status}`);
        }

        const data = await response.json();
        displayPredictions(data.predictions);

    } catch (error) {
        console.error('Prediction failed:', error);
        errorMessage.textContent = `Prediction failed: ${error.message}`;
        errorMessage.classList.remove('hidden');
    } finally {
        // Runs on success and failure alike, so the spinner can never get stuck.
        loadingIndicator.classList.add('hidden');
        predictButton.disabled = false; // Re-enable button
    }
}
171
+
172
+ // --- Display Logic ---
173
/**
 * Render the predicted labels as colored badges in the results area.
 *
 * Clears the previous output and hides the error message first. A missing or
 * empty list produces the "no hallmarks" placeholder instead of badges.
 * Labels without an entry in hallmarkColors get the default gray class.
 */
function displayPredictions(predictions) {
    predictionOutput.innerHTML = ''; // Clear previous just in case
    errorMessage.classList.add('hidden');

    const hasLabels = Boolean(predictions && predictions.length > 0);
    if (!hasLabels) {
        predictionOutput.innerHTML = '<span class="text-gray-500 italic text-sm">No specific hallmarks predicted for this abstract.</span>';
        return;
    }

    predictions.forEach(function appendBadge(label) {
        const badge = document.createElement('span');
        // textContent, not innerHTML: the label is inserted as plain text.
        badge.textContent = label;
        const colorClass = hallmarkColors[label] || 'hallmark-color-9'; // Default color
        badge.className = 'hallmark-badge ' + colorClass;
        predictionOutput.appendChild(badge);
    });
}
189
+ </script>
190
+ </body>
191
+ </html>