Priyansu19 committed on
Commit 385e83f · 1 Parent(s): b347ca3

adding finetune code

Files changed (1)
  1. bio_gpt_finetune.py +2082 -0
bio_gpt_finetune.py ADDED
@@ -0,0 +1,2082 @@
1
+ # -*- coding: utf-8 -*-
2
+ """Vizuara BioGPT from Scratch.ipynb
3
+
4
+ Automatically generated by Colab.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1ys-b99GalAtTE9m7bGwCCACZYv2M8HjO
8
+
9
+ #Vizuara AI Labs: BioGPT Pre-training + Finetuning
10
+
11
+ ## Part 1: Pre-training
12
+
13
+ ### 1.1 Loading the dataset
14
+ """
15
+
16
+ # Colab: download a subset of the PubMed baseline XML (uncompressed target size set below)
17
+ import os, re, subprocess, math, requests
18
+ from bs4 import BeautifulSoup
19
+ from urllib.parse import urljoin
20
+
21
+ BASE_URL = "https://ftp.ncbi.nlm.nih.gov/pubmed/baseline/"
22
+ TARGET_UNCOMPRESSED_GB = 1.0
23
+ DEST = "/content/pubmed_xml_subset"
24
+ os.makedirs(DEST, exist_ok=True)
25
+
26
+ # 1) Fetch list of .gz files from the baseline index
27
+ html = requests.get(BASE_URL, timeout=60).text
28
+ soup = BeautifulSoup(html, "html.parser")
29
+
30
+ # All .gz files (e.g., pubmed24n0001.xml.gz)
31
+ hrefs = [a.get("href") for a in soup.find_all("a", href=True)]
32
+ gz_files = sorted([h for h in hrefs if h.endswith(".gz")])
33
+
34
+ print(f"Found {len(gz_files)} .gz files on the baseline index.")
35
+
36
+ # 2) Download sequentially until uncompressed total ≈ target
37
+ def gz_uncompressed_bytes(local_path):
38
+ # Use gzip -l to read the uncompressed size from the footer (fast; no full decompress; the footer stores the size modulo 2^32, which is fine for these files)
39
+ out = subprocess.check_output(["gzip", "-l", local_path]).decode()
40
+ # The second line has: compressed uncompressed ratio uncompressed_name
41
+ lines = out.strip().splitlines()
42
+ if len(lines) >= 2:
43
+ parts = re.split(r"\s+", lines[1].strip())
44
+ # parts[1] = uncompressed bytes
45
+ return int(parts[1])
46
+ return 0
47
+
48
+ total_uncompressed = 0
49
+ downloaded = []
50
+
51
+ for fname in gz_files:
52
+ url = urljoin(BASE_URL, fname)
53
+ local = os.path.join(DEST, fname)
54
+ if not os.path.exists(local):
55
+ print(f"→ downloading {fname} ...")
56
+ # quiet, continue on partial, retry a bit
57
+ ret = subprocess.call(["wget", "-q", "-c", "-O", local, url])
58
+ if ret != 0:
59
+ print(f" ! failed: {fname}; skipping")
60
+ if os.path.exists(local): os.remove(local)
61
+ continue
62
+ # read uncompressed size
63
+ try:
64
+ ub = gz_uncompressed_bytes(local)
65
+ total_uncompressed += ub
66
+ downloaded.append((fname, ub))
67
+ print(f" added {fname}: {ub/1e9:.3f} GB uncompressed | total ≈ {total_uncompressed/1e9:.3f} GB")
68
+ except Exception as e:
69
+ print(f" ! could not read size for {fname}: {e}")
70
+
71
+ if total_uncompressed >= TARGET_UNCOMPRESSED_GB * 1e9:
72
+ print("\nTarget reached. Stopping downloads.")
73
+ break
74
+
75
+ print(f"\nDone. Saved {len(downloaded)} files to: {DEST}")
76
+ print(f"Approx. uncompressed total: {total_uncompressed/1e9:.3f} GB")
77
+
78
+ """### 1.2 Converting title and abstract from XML to TXT"""
79
+
80
+ # Colab cell: Parse title + abstract to plain text (one doc/line)
81
+ import os, gzip, glob
82
+ from lxml import etree
83
+ from tqdm import tqdm
84
+
85
+ SRC_DIR = "/content/pubmed_xml_subset" # where your .xml.gz files are
86
+ OUT_DIR = "/content/pubmed_txt" # output folder
87
+ os.makedirs(OUT_DIR, exist_ok=True)
88
+
89
+ train_path = f"{OUT_DIR}/train.txt"
90
+ valid_path = f"{OUT_DIR}/valid.txt"
91
+ test_path = f"{OUT_DIR}/test.txt"
92
+
93
+ # ----- helper: stream-parse one PubMed file -----
94
+ def yield_title_abstract(fp):
95
+ # iterparse to avoid loading whole XML into RAM
96
+ ctx = etree.iterparse(gzip.open(fp), events=("end",), tag="PubmedArticle")
97
+ for _, elem in ctx:
98
+ # Title
99
+ t = elem.find(".//ArticleTitle")
100
+ title = (t.text or "").strip() if t is not None else ""
101
+ # Abstract may have multiple parts <AbstractText>
102
+ abs_nodes = elem.findall(".//AbstractText")
103
+ abs_parts = []
104
+ for a in abs_nodes:
105
+ txt = (a.text or "").strip()
106
+ if txt:
107
+ abs_parts.append(txt)
108
+ abstract = " ".join(abs_parts).strip()
109
+
110
+ if title and abstract:
111
+ text = f"{title}. {abstract}"
112
+ # clean newlines/tabs
113
+ text = " ".join(text.split())
114
+ yield text
115
+
116
+ # free memory
117
+ elem.clear()
118
+ while elem.getprevious() is not None:
119
+ del elem.getparent()[0]
120
+ del ctx
121
+
122
+ # ----- collect and write -----
123
+ gz_files = sorted(glob.glob(os.path.join(SRC_DIR, "*.xml.gz")))
124
+ print(f"Found {len(gz_files)} gz files")
125
+
126
+ # We'll stream all docs, then do a simple split by count.
127
+ all_out = f"{OUT_DIR}/_all.txt"
128
+ with open(all_out, "w", encoding="utf-8") as out:
129
+ for fp in tqdm(gz_files, desc="Parsing"):
130
+ for line in yield_title_abstract(fp):
131
+ out.write(line + "\n")
132
+
133
+ # Quick stats
134
+ num_lines = sum(1 for _ in open(all_out, "r", encoding="utf-8"))
135
+ print("Total docs with title+abstract:", num_lines)
136
+
137
+ # Split 98% / 1% / 1% (adjust if you like)
138
+ train_n = int(num_lines * 0.98)
139
+ valid_n = int(num_lines * 0.01)
140
+ test_n = num_lines - train_n - valid_n
141
+
142
+ with open(all_out, "r", encoding="utf-8") as fin, \
143
+ open(train_path, "w", encoding="utf-8") as ftr, \
144
+ open(valid_path, "w", encoding="utf-8") as fva, \
145
+ open(test_path, "w", encoding="utf-8") as fte:
146
+ for i, line in enumerate(fin):
147
+ if i < train_n: ftr.write(line)
148
+ elif i < train_n + valid_n: fva.write(line)
149
+ else: fte.write(line)
150
+
151
+ print("Wrote:")
152
+ print(" ", train_path)
153
+ print(" ", valid_path)
154
+ print(" ", test_path)
155
+
156
+ # Commented out IPython magic to ensure Python compatibility.
157
+ # Colab cell: Install tools
158
+ !pip -q install sacremoses==0.0.53
159
+ !sudo apt-get -y install g++ >/dev/null
160
+
161
+ # fastBPE (build once)
162
+ !git clone -q https://github.com/glample/fastBPE.git /content/fastBPE
163
+ # %cd /content/fastBPE
164
+ !g++ -std=c++11 -O3 -pthread fastBPE/main.cc -IfastBPE -o fast
165
+ # %cd /content
166
+
167
+ # fairseq (0.12.0 recommended for GPT2-medium arch flag)
168
+ !git clone -q https://github.com/pytorch/fairseq.git /content/fairseq
169
+ # %cd /content/fairseq
170
+ !git checkout v0.12.0 -q
171
+ !pip -q install .
172
+ # %cd /content
173
+
174
+ """### 1.3 Fetch the BioGPT vocabulary and BPE merge codes"""
175
+
176
+ # Colab cell: Grab BioGPT bpecodes/dict
177
+ !wget -q -O /content/bpecodes https://raw.githubusercontent.com/microsoft/BioGPT/main/data/BioGPT/bpecodes
178
+ !wget -q -O /content/dict.txt https://raw.githubusercontent.com/microsoft/BioGPT/main/data/BioGPT/dict.txt
179
+ !wc -l /content/dict.txt && head -n 5 /content/dict.txt
180
+
181
+ """### 1.4 Use Moses tokenizer to clean text before applying BPE"""
182
+
183
+ import os
184
+ from sacremoses import MosesTokenizer
185
+ from tqdm.auto import tqdm
186
+
187
+ TXT_DIR = "/content/pubmed_txt"
188
+ BPE_DIR = "/content/pubmed_bpe"
189
+ os.makedirs(BPE_DIR, exist_ok=True)
190
+
191
+ mt = MosesTokenizer(lang="en")
192
+
193
+ def tokenize_file(in_path, out_path, show_progress=True):
194
+ # Count lines once for a nice total
195
+ with open(in_path, "r", encoding="utf-8") as f:
196
+ total = sum(1 for _ in f)
197
+
198
+ with open(in_path, "r", encoding="utf-8") as fin, \
199
+ open(out_path, "w", encoding="utf-8") as fout:
200
+ iterator = fin
201
+ if show_progress:
202
+ iterator = tqdm(fin, total=total, desc=f"Tokenizing {os.path.basename(in_path)}")
203
+ for line in iterator:
204
+ line = line.strip()
205
+ if not line:
206
+ continue
207
+ fout.write(mt.tokenize(line, return_str=True) + "\n")
208
+
209
+ for split in ["train", "valid", "test"]:
210
+ tok = f"{BPE_DIR}/{split}.tok"
211
+ bpe = f"{BPE_DIR}/{split}.bpe"
212
+ tokenize_file(f"{TXT_DIR}/{split}.txt", tok)
213
+
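+ # Illustrative sanity check (assumes the `mt` tokenizer defined above): Moses splits off
+ # punctuation so that BPE later sees clean, space-separated tokens.
+ sample_sentence = "The p53 protein regulates apoptosis."
+ print("Moses example:", mt.tokenize(sample_sentence, return_str=True))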
214
+ """### 1.5 Apply BPE to dataset"""
215
+
216
+ # Commented out IPython magic to ensure Python compatibility.
217
+ import os, math, subprocess, numpy as np, shutil
218
+ from tqdm.auto import tqdm
219
+
220
+ BPE_CODES = "/content/bpecodes" # BioGPT bpecodes
221
+ DICT_TXT = "/content/dict.txt" # BioGPT dict
222
+ BPE_DIR = "/content/pubmed_bpe" # where your .tok files are
223
+ BIN_DIR = "/content/pubmed_memmap"
224
+ TMP_DIR = "/content/_bpe_tmp"
225
+ os.makedirs(BIN_DIR, exist_ok=True)
226
+ os.makedirs(TMP_DIR, exist_ok=True)
227
+
228
+ # --- load vocab ---
229
+ token2id = {}
230
+ with open(DICT_TXT, encoding="utf-8") as f:
231
+ for i, line in enumerate(f):
232
+ tok = line.split()[0]
233
+ token2id[tok] = i
234
+ # choose a fallback id ONLY IF we see OOVs later
235
+ fallback_id = token2id.get("</s>", next(iter(token2id.values()))) # prefer EOS, else first token
236
+
237
+ # --- ensure fastBPE binary exists ---
238
+ if not os.path.exists("/content/fastBPE/fast"):
239
+ !git clone -q https://github.com/glample/fastBPE.git /content/fastBPE
240
+ # %cd /content/fastBPE
241
+ !g++ -std=c++11 -O3 -pthread fastBPE/main.cc -IfastBPE -o fast
242
+ # %cd /content
243
+
244
+ def line_count(path):
245
+ c = 0
246
+ with open(path, encoding="utf-8") as f:
247
+ for _ in f:
248
+ c += 1
249
+ return c
250
+
251
+ def apply_bpe_with_progress(tok_file, bpe_file, shards=50):
252
+ total_lines = line_count(tok_file)
253
+ if total_lines == 0:
254
+ open(bpe_file, "w").close()
255
+ return
256
+
257
+ shards = max(1, min(shards, total_lines))
258
+ lines_per = math.ceil(total_lines / shards)
259
+
260
+ split_dir = os.path.join(TMP_DIR, "split")
261
+ out_dir = os.path.join(TMP_DIR, "bpe_parts")
262
+ os.makedirs(split_dir, exist_ok=True)
263
+ os.makedirs(out_dir, exist_ok=True)
264
+
265
+ # 1) split with progress
266
+ with open(tok_file, encoding="utf-8") as fin:
267
+ shard_idx = 0
268
+ line_idx = 0
269
+ fout = None
270
+ pbar = tqdm(total=total_lines, desc=f"Splitting {os.path.basename(tok_file)}")
271
+ for line in fin:
272
+ if line_idx % lines_per == 0:
273
+ if fout: fout.close()
274
+ shard_idx += 1
275
+ fout = open(os.path.join(split_dir, f"part_{shard_idx:05d}.tok"), "w", encoding="utf-8")
276
+ fout.write(line)
277
+ line_idx += 1
278
+ pbar.update(1)
279
+ if fout: fout.close()
280
+ pbar.close()
281
+
282
+ # 2) BPE on each shard with progress
283
+ parts = sorted([p for p in os.listdir(split_dir) if p.endswith(".tok")])
284
+ for p in tqdm(parts, desc="Applying BPE to shards"):
285
+ src = os.path.join(split_dir, p)
286
+ dst = os.path.join(out_dir, p.replace(".tok", ".bpe"))
287
+ subprocess.check_call(["/content/fastBPE/fast", "applybpe", dst, src, BPE_CODES])
288
+
289
+ # 3) concat with progress
290
+ with open(bpe_file, "w", encoding="utf-8") as fout:
291
+ for p in tqdm(parts, desc="Concatenating BPE shards"):
292
+ src = os.path.join(out_dir, p.replace(".tok", ".bpe"))
293
+ with open(src, encoding="utf-8") as fin:
294
+ shutil.copyfileobj(fin, fout)
295
+
296
+ shutil.rmtree(split_dir, ignore_errors=True)
297
+ shutil.rmtree(out_dir, ignore_errors=True)
298
+
299
+ def make_bin(split, dtype=np.uint16, shards=64):
300
+ tok_file = os.path.join(BPE_DIR, f"{split}.tok")
301
+ bpe_file = os.path.join(BPE_DIR, f"{split}.bpe")
302
+
303
+ print(f"\n[{split}] Step 1: Applying BPE merges with progress...")
304
+ apply_bpe_with_progress(tok_file, bpe_file, shards=shards)
305
+
306
+ print(f"[{split}] Step 2: Counting total tokens...")
307
+ total_tokens, total_lines = 0, 0
308
+ with open(bpe_file, encoding="utf-8") as f:
309
+ for line in tqdm(f, desc="Counting tokens"):
310
+ total_tokens += len(line.strip().split())
311
+ total_lines += 1
312
+ print(f"[{split}] Total tokens: {total_tokens:,} | lines: {total_lines:,}")
313
+
314
+ print(f"[{split}] Step 3: Encoding to IDs & writing memmap...")
315
+ bin_path = os.path.join(BIN_DIR, f"{split}.bin")
316
+ arr = np.memmap(bin_path, dtype=dtype, mode="w+", shape=(total_tokens,))
317
+
318
+ idx = 0
319
+ oov_count = 0
320
+ oov_samples = {}
321
+ with open(bpe_file, encoding="utf-8") as f:
322
+ for line in tqdm(f, total=total_lines, desc=f"Encoding {split}"):
323
+ toks = line.strip().split()
324
+ ids = []
325
+ for t in toks:
326
+ if t in token2id:
327
+ ids.append(token2id[t])
328
+ else:
329
+ oov_count += 1
330
+ if len(oov_samples) < 10:
331
+ oov_samples[t] = oov_samples.get(t, 0) + 1
332
+ ids.append(fallback_id) # safe fallback if any OOVs occur
333
+ n = len(ids)
334
+ arr[idx:idx+n] = np.fromiter(ids, dtype=dtype, count=n)
335
+ idx += n
336
+ arr.flush()
337
+
338
+ if oov_count == 0:
339
+ print(f"[{split}] ✅ Saved {bin_path} (no OOVs)")
340
+ else:
341
+ print(f"[{split}] ⚠️ Saved {bin_path} with {oov_count} OOV tokens mapped to id {fallback_id}.")
342
+ print(" First few OOV examples:", list(oov_samples.items()))
343
+
344
+ for split in ["train", "valid", "test"]:
345
+ make_bin(split, dtype=np.uint16, shards=64)
346
+
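+ # Optional read-back check (a sketch, not required for training): re-open one memmap
+ # read-only and decode the first few ids back to BPE tokens using the dict loaded above.
+ check = np.memmap(os.path.join(BIN_DIR, "valid.bin"), dtype=np.uint16, mode="r")
+ id2tok_check = {i: t for t, i in token2id.items()}
+ print("valid.bin tokens:", check.shape[0])
+ print("first 20 tokens:", " ".join(id2tok_check[int(i)] for i in check[:20]))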
347
+ """### 1.6 Create input-output pairs"""
348
+
349
+ import os, numpy as np, torch
350
+
351
+ BIN_ROOT = "/content/pubmed_memmap" # where your .bin files are
352
+ DTYPE = np.uint16 # you saved with uint16
353
+
354
+ def get_batch(split):
355
+ fname = "train.bin" if split == "train" else "valid.bin"
356
+ path = os.path.join(BIN_ROOT, fname)
357
+ data = np.memmap(path, dtype=DTYPE, mode='r')
358
+
359
+ ix = torch.randint(len(data) - block_size, (batch_size,))
360
+ x = torch.stack([torch.from_numpy(data[i:i+block_size].astype(np.int64)) for i in ix])
361
+ y = torch.stack([torch.from_numpy(data[i+1:i+1+block_size].astype(np.int64)) for i in ix])
362
+
363
+ if device_type == 'cuda':
364
+ x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
365
+ else:
366
+ x, y = x.to(device), y.to(device)
367
+ return x, y
368
+
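+ # Minimal shape check for the batching above (a sketch; block_size, batch_size, device
+ # and device_type are defined in the training-config cell further down):
+ # xb, yb = get_batch("valid")
+ # print(xb.shape, yb.shape) # both (batch_size, block_size)
+ # print(torch.equal(xb[0, 1:], yb[0, :-1])) # True: targets are the inputs shifted by one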
369
+ """### 1.7 Define BioGPT architecture"""
370
+
371
+ import torch
372
+ import torch.nn as nn
373
+ import torch.nn.functional as F
374
+ import math
375
+ from dataclasses import dataclass
376
+ import numpy as np
377
+ from tqdm.auto import tqdm
378
+ from contextlib import nullcontext
379
+ import os
380
+
381
+ class LayerNorm(nn.Module):
382
+ def __init__(self, ndim, bias):
383
+ super().__init__()
384
+ self.weight = nn.Parameter(torch.ones(ndim))
385
+ self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
386
+ def forward(self, x):
387
+ return F.layer_norm(x, self.weight.shape, self.weight, self.bias, 1e-5)
388
+
389
+ class CausalSelfAttention(nn.Module):
390
+ def __init__(self, config):
391
+ super().__init__()
392
+ assert config.n_embd % config.n_head == 0
393
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
394
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
395
+ self.attn_dropout = nn.Dropout(config.dropout)
396
+ self.resid_dropout = nn.Dropout(config.dropout)
397
+ self.n_head = config.n_head
398
+ self.n_embd = config.n_embd
399
+ self.flash = hasattr(F, 'scaled_dot_product_attention')
400
+ if not self.flash:
401
+ self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
402
+ .view(1, 1, config.block_size, config.block_size))
403
+
404
+ def forward(self, x):
405
+ B, T, C = x.size()
406
+ q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
407
+ k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
408
+ q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
409
+ v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
410
+
411
+ if self.flash:
412
+ y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.attn_dropout.p if self.training else 0.0, is_causal=True)
413
+ else:
414
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
415
+ att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
416
+ att = F.softmax(att, dim=-1)
417
+ att = self.attn_dropout(att)
418
+ y = att @ v
419
+
420
+ y = y.transpose(1, 2).contiguous().view(B, T, C)
421
+ y = self.resid_dropout(self.c_proj(y))
422
+ return y
423
+
424
+ class MLP(nn.Module):
425
+ def __init__(self, config):
426
+ super().__init__()
427
+ self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
428
+ self.gelu = nn.GELU()
429
+ self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
430
+ self.dropout = nn.Dropout(config.dropout)
431
+ def forward(self, x):
432
+ return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))
433
+
434
+ class Block(nn.Module):
435
+ def __init__(self, config):
436
+ super().__init__()
437
+ self.ln1 = LayerNorm(config.n_embd, config.bias)
438
+ self.attn = CausalSelfAttention(config)
439
+ self.ln2 = LayerNorm(config.n_embd, config.bias)
440
+ self.mlp = MLP(config)
441
+ def forward(self, x):
442
+ x = x + self.attn(self.ln1(x))
443
+ x = x + self.mlp(self.ln2(x))
444
+ return x
445
+
446
+ @dataclass
447
+ class GPTConfig:
448
+ block_size: int
449
+ vocab_size: int
450
+ n_layer: int
451
+ n_head: int
452
+ n_embd: int
453
+ dropout: float = 0.0
454
+ bias: bool = True
455
+
456
+ class GPT(nn.Module):
457
+ def __init__(self, config):
458
+ super().__init__()
459
+ self.config = config
460
+ self.transformer = nn.ModuleDict(dict(
461
+ wte=nn.Embedding(config.vocab_size, config.n_embd),
462
+ wpe=nn.Embedding(config.block_size, config.n_embd),
463
+ drop=nn.Dropout(config.dropout),
464
+ h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
465
+ ln_f=LayerNorm(config.n_embd, config.bias),
466
+ ))
467
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
468
+ self.transformer.wte.weight = self.lm_head.weight # weight tying
469
+
470
+ self.apply(self._init_weights)
471
+ for pn, p in self.named_parameters():
472
+ if pn.endswith('c_proj.weight'):
473
+ nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))
474
+
475
+ def _init_weights(self, module):
476
+ if isinstance(module, nn.Linear):
477
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
478
+ if module.bias is not None:
479
+ nn.init.zeros_(module.bias)
480
+ elif isinstance(module, nn.Embedding):
481
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
482
+
483
+ def forward(self, idx, targets=None):
484
+ device = idx.device
485
+ b, t = idx.size()
486
+ assert t <= self.config.block_size
487
+ pos = torch.arange(0, t, dtype=torch.long, device=device)
488
+
489
+ tok_emb = self.transformer.wte(idx)
490
+ pos_emb = self.transformer.wpe(pos)
491
+ x = self.transformer.drop(tok_emb + pos_emb)
492
+ for block in self.transformer.h:
493
+ x = block(x)
494
+ x = self.transformer.ln_f(x)
495
+
496
+ if targets is not None:
497
+ logits = self.lm_head(x)
498
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
499
+ return logits, loss
500
+ else:
501
+ logits = self.lm_head(x[:, [-1], :])
502
+ return logits, None
503
+
504
+ @torch.no_grad()
505
+ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
506
+ """
507
+ Generate tokens given a conditioning sequence.
508
+ idx: Tensor of shape (B, T)
509
+ """
510
+ for _ in range(max_new_tokens):
511
+ idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
512
+ logits, _ = self(idx_cond)
513
+ logits = logits[:, -1, :] / temperature
514
+ if top_k is not None:
515
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
516
+ logits[logits < v[:, [-1]]] = -float('Inf')
517
+ probs = F.softmax(logits, dim=-1)
518
+ idx_next = torch.multinomial(probs, num_samples=1)
519
+ idx = torch.cat((idx, idx_next), dim=1)
520
+ return idx
521
+
522
+ vocab_size = sum(1 for _ in open("/content/dict.txt", encoding="utf-8"))
523
+ print("Vocab size:", vocab_size) # should be ~42380
524
+
525
+ """### 1.8 Define configuration"""
526
+
527
+ # Pick GPU if available, else CPU
528
+ device = "cuda" if torch.cuda.is_available() else "cpu"
529
+
530
+ # Optional: keep track of the type for AMP autocast
531
+ device_type = 'cuda' if device == 'cuda' else 'cpu'
532
+
533
+ # Now build the config
534
+ vocab_size = sum(1 for _ in open("/content/dict.txt", encoding="utf-8"))
535
+ config = GPTConfig(
536
+ vocab_size=vocab_size,
537
+ block_size=128, # or 1024 for BioGPT-scale training
538
+ n_layer=6, # change to 24 for BioGPT-size
539
+ n_head=6, # change to 16 for BioGPT-size
540
+ n_embd=384, # change to 1024 for BioGPT-size
541
+ dropout=0.1,
542
+ bias=True
543
+ )
544
+
545
+ # Create model and move to device
546
+ model = GPT(config).to(device)
547
+ print("Params (M):", sum(p.numel() for p in model.parameters())/1e6)
548
+
549
+ print(vocab_size)
550
+
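+ # Optional smoke test of the (still untrained) model's generate() method; output will be
+ # gibberish at this point. A sketch, assuming token2id from the BPE cell is in scope:
+ # prime = torch.tensor([[token2id.get("</s>", 0)]], dtype=torch.long, device=device)
+ # print(model.generate(prime, max_new_tokens=10))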
551
+ """### 1.9 Define loss function"""
552
+
553
+ def estimate_loss(model):
554
+ out = {}
555
+ model.eval()
556
+ with torch.inference_mode():
557
+ for split in ['train', 'valid']:
558
+ losses = torch.zeros(eval_iters)
559
+ for k in range(eval_iters):
560
+ X, Y = get_batch(split)
561
+ with ctx:
562
+ logits, loss = model(X, Y)
563
+ losses[k] = loss.item()
564
+ out[split] = losses.mean()
565
+ model.train()
566
+ return out
567
+
568
+ """### 1.10 Define the training configuration"""
569
+
570
+ # Training Config
571
+ import torch
572
+ from contextlib import nullcontext
573
+
574
+ learning_rate = 1e-4 # base LR; kept modest for more stable training
575
+ max_iters = 120000 #increase from 25000
576
+ warmup_steps = 1000 # longer warmup for a smoother start (was 100)
577
+ min_lr = 1e-5 # eta_min for cosine decay; must stay below learning_rate
578
+ eval_iters = 500 # increased from 100
579
+ batch_size = 32 # changed from 16, better gradient estimate
580
+ block_size = 128 #changed from 64, capture longer range dependencies
581
+
582
+ gradient_accumulation_steps = 32 # reduced from 50
583
+
584
+ device = "cuda" if torch.cuda.is_available() else "cpu"
585
+ device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
586
+ # note: float16 data type will automatically use a GradScaler
587
+
588
+ # How to use autocast https://wandb.ai/wandb_fc/tips/reports/How-To-Use-Autocast-in-PyTorch--VmlldzoyMTk4NTky
590
+ dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
591
+ ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
592
+
593
+ ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
594
+
595
+ torch.set_default_device(device)
596
+ torch.manual_seed(42)
597
+
598
+ """### 1.11 Define optimizers and learning rate"""
599
+
600
+ from torch.optim.lr_scheduler import LinearLR,SequentialLR, CosineAnnealingLR
601
+
602
+ # AdamW with weight decay (0.1) for regularization and beta2 lowered to 0.95, GPT-style
603
+ optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.95), weight_decay=0.1, eps=1e-9) #weight decay for regularization
604
+
605
+ scheduler_warmup = LinearLR(optimizer, total_iters = warmup_steps) #Implement linear warmup
606
+ scheduler_decay = CosineAnnealingLR(optimizer,T_max = max_iters - warmup_steps, eta_min = min_lr) #Implement lr decay
607
+ scheduler = SequentialLR(optimizer, schedulers=[scheduler_warmup, scheduler_decay], milestones=[warmup_steps]) #Switching from warmup to decay
608
+
609
+ # https://stackoverflow.com/questions/72534859/is-gradscaler-necessary-with-mixed-precision-training-with-pytorch
610
+ scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
611
+
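+ # Dry run of the warmup + cosine schedule (a sketch on a throwaway optimizer so the real
+ # one above is untouched); prints the learning rate at a few milestones.
+ _probe_opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=learning_rate)
+ _probe_sched = SequentialLR(
+     _probe_opt,
+     schedulers=[LinearLR(_probe_opt, total_iters=warmup_steps),
+                 CosineAnnealingLR(_probe_opt, T_max=max_iters - warmup_steps, eta_min=min_lr)],
+     milestones=[warmup_steps])
+ for _step in range(max_iters):
+     _probe_opt.step(); _probe_sched.step()
+     if _step in (0, warmup_steps - 1, max_iters // 2, max_iters - 1):
+         print(f"step {_step}: lr = {_probe_opt.param_groups[0]['lr']:.2e}")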
612
+ """### 1.12 Run pre-training!"""
613
+
614
+ best_val_loss = float('inf')
615
+ best_model_params_path = "best_model_params.pt"
616
+ train_loss_list, validation_loss_list = [], []
617
+
618
+ # Ensure model is on the correct device
619
+ model = model.to(device)
620
+
621
+ # Training loop (note: the loop variable "epoch" is really a per-batch iteration counter)
622
+ for epoch in tqdm(range(max_iters)):
623
+ if epoch % eval_iters == 0 and epoch != 0:
624
+ # Ensure estimate_loss uses the correct device
625
+ losses = estimate_loss(model)
626
+ print(f"Epoch {epoch}: train loss {losses['train']:.4f}, val loss {losses['valid']:.4f}")
627
+ print(f"The current learning rate: {optimizer.param_groups[0]['lr']:.5f}")
628
+ train_loss_list += [losses['train']]
629
+ validation_loss_list += [losses['valid']]
630
+
631
+ if losses['valid'] < best_val_loss:
632
+ best_val_loss = losses['valid']
633
+ torch.save(model.state_dict(), best_model_params_path)
634
+
635
+ # Ensure X and y are on the correct device
636
+ X, y = get_batch("train")
637
+ X, y = X.to(device), y.to(device)
638
+
639
+ with ctx:
640
+ logits, loss = model(X, y)
641
+ loss = loss / gradient_accumulation_steps
642
+ scaler.scale(loss).backward()
643
+
644
+ if ((epoch + 1) % gradient_accumulation_steps == 0) or (epoch + 1 == max_iters):
645
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
646
+ scaler.step(optimizer)
647
+ scaler.update()
648
+ optimizer.zero_grad(set_to_none=True)
649
+ scheduler.step()
650
+
651
+ """### 1.13 Plot training and validation losses"""
652
+
653
+ import matplotlib.pyplot as plt
654
+ import numpy as np
655
+
656
+ eval_every = eval_iters # e.g., 500
657
+
658
+ # Convert each tensor to float on CPU
659
+ train_loss_np = [float(t.cpu()) for t in train_loss_list]
660
+ valid_loss_np = [float(t.cpu()) for t in validation_loss_list]
661
+
662
+ steps = np.arange(1, len(train_loss_np) + 1) * eval_every
663
+
664
+ plt.figure(figsize=(6,4))
665
+ plt.plot(steps, train_loss_np, label='train')
666
+ plt.plot(steps, valid_loss_np, label='valid')
667
+ plt.xlabel('Iteration')
668
+ plt.ylabel('Loss')
669
+ plt.title('Pretraining loss')
670
+ plt.legend()
671
+ plt.grid(True, alpha=0.3)
672
+ plt.show()
673
+
674
+ import torch
675
+
676
+ ckpt_path = "best_model_params.pt" # you saved this in the loop
677
+ model.load_state_dict(torch.load(ckpt_path, map_location=device))
678
+ model.eval()
679
+
680
+ """### 1.14 Evaluation on HoC Part 1 (the Hallmarks of Cancer corpus) classification dataset"""
681
+
682
+ import os
683
+ import pandas as pd
684
+ from datasets import load_dataset
685
+ from tqdm.auto import tqdm
686
+
687
+ def download_and_save_hoc_splits(target_dir="/content/hoc"):
688
+ """
689
+ Downloads the bigbio/hallmarks_of_cancer dataset from Hugging Face,
690
+ formats it, and saves it as train.tsv, valid.tsv, and test.tsv
691
+ in the specified directory.
692
+
693
+ Args:
694
+ target_dir (str): The directory to save the .tsv files.
695
+ """
696
+ print("Downloading bigbio/hallmarks_of_cancer dataset...")
697
+ try:
698
+ # Load the dataset splits
699
+ train_data = load_dataset("bigbio/hallmarks_of_cancer", split="train")
700
+ valid_data = load_dataset("bigbio/hallmarks_of_cancer", split="validation")
701
+ test_data = load_dataset("bigbio/hallmarks_of_cancer", split="test")
702
+ print("Dataset downloaded successfully.")
703
+ except Exception as e:
704
+ print(f"Error downloading dataset: {e}")
705
+ print("Please ensure you have internet access and the 'datasets' library is installed (`pip install datasets`).")
706
+ return
707
+
708
+ os.makedirs(target_dir, exist_ok=True)
709
+ print(f"Ensured target directory exists: {target_dir}")
710
+
711
+ splits = {
712
+ "train": train_data,
713
+ "valid": valid_data,
714
+ "test": test_data,
715
+ }
716
+
717
+ for split_name, dataset in splits.items():
718
+ output_path = os.path.join(target_dir, f"{split_name}.tsv")
719
+ print(f"Processing '{split_name}' split and saving to {output_path}...")
720
+
721
+ processed_data = []
722
+ # Iterate with tqdm for progress bar
723
+ for item in tqdm(dataset, desc=f"Processing {split_name}", leave=False):
724
+ text = item.get("text", "")
725
+ labels_list = item.get("labels", [])
726
+
727
+ # Handle the [' none '] case and join the list into a string
728
+ # Using '; ' as a separator, similar to how multi-label strings might appear
729
+ if labels_list == [' none '] or not labels_list:
730
+ label_str = "" # Represent 'none' or empty list as an empty string
731
+ else:
732
+ # Filter out ' none ' if mixed with others, though unlikely based on dataset viewer
733
+ valid_labels = [lbl for lbl in labels_list if lbl.strip().lower() != 'none']
734
+ label_str = "; ".join(valid_labels) # Join valid labels with a separator
735
+
736
+ # Append as a dictionary for easy DataFrame creation later
737
+ # Replace tabs and newlines in text to avoid breaking TSV format
738
+ cleaned_text = " ".join(text.split())
739
+ processed_data.append({"text": cleaned_text, "label": label_str})
740
+
741
+ # Convert to DataFrame and save as TSV
742
+ if processed_data:
743
+ df = pd.DataFrame(processed_data)
744
+ # Ensure columns are in the order expected by load_hoc_tsv heuristic (text, label)
745
+ df = df[["text", "label"]]
746
+ df.to_csv(output_path, sep="\t", index=False, header=False) # Save without index and header
747
+ print(f"Successfully saved {output_path}")
748
+ else:
749
+ print(f"No data processed for split '{split_name}'.")
750
+
751
+ print("\nDataset processing complete.")
752
+
753
+ # Commented out IPython magic to ensure Python compatibility.
754
+ # ===== Zero-shot HoC evaluation for your PRE-TRAINED GPT (with cue + EOS delay) =====
755
+ # Uses your existing GPT / GPTConfig and loads ckpt_path="best_model_params.pt"
756
+
757
+ # installs
758
+ !pip -q install sacremoses==0.0.53 scikit-learn==1.5.1
759
+
760
+ import os, math, difflib, tempfile, subprocess
761
+ import numpy as np
762
+ import pandas as pd
763
+ from tqdm.auto import tqdm
764
+ import torch
765
+ import torch.nn.functional as F
766
+ from sklearn.metrics import precision_recall_fscore_support
767
+ from sacremoses import MosesDetokenizer
768
+
769
+ # ---------- paths ----------
770
+ HOC_DIR = "/content/hoc"
771
+ download_and_save_hoc_splits(HOC_DIR) # train.tsv / valid.tsv / test.tsv live here
772
+ BPE_CODES = "/content/bpecodes" # from BioGPT
773
+ DICT_TXT = "/content/dict.txt" # from BioGPT
774
+ FASTBPE_BIN = "/content/fastBPE/fast" # compiled earlier
775
+ ckpt_path = ckpt_path if 'ckpt_path' in globals() else "best_model_params.pt"
776
+
777
+ os.makedirs(HOC_DIR, exist_ok=True)
778
+
779
+ # ---------- ensure fastBPE + BioGPT codes/dict ----------
780
+ if not os.path.exists(FASTBPE_BIN):
781
+ !git clone -q https://github.com/glample/fastBPE.git /content/fastBPE
782
+ # %cd /content/fastBPE
783
+ !g++ -std=c++11 -O3 -pthread fastBPE/main.cc -IfastBPE -o fast
784
+ # %cd /content
785
+ if not os.path.exists(BPE_CODES):
786
+ !wget -q -O /content/bpecodes https://raw.githubusercontent.com/microsoft/BioGPT/main/data/BioGPT/bpecodes
787
+ if not os.path.exists(DICT_TXT):
788
+ !wget -q -O /content/dict.txt https://raw.githubusercontent.com/microsoft/BioGPT/main/data/BioGPT/dict.txt
789
+
790
+ # ---------- vocab maps ----------
791
+ token2id, id2token = {}, {}
792
+ with open(DICT_TXT, encoding="utf-8") as f:
793
+ for i, line in enumerate(f):
794
+ tok = line.split()[0]
795
+ token2id[tok] = i
796
+ id2token[i] = tok
797
+
798
+ eos_id = token2id.get("</s>", 0)
799
+ pad_id = eos_id # safe pad; loss is masked anyway
800
+
801
+ # ---------- BPE helpers ----------
802
+ def bpe_encode_lines(lines, shard_size=2000, desc="BPE"):
803
+ if len(lines) == 0:
804
+ return []
805
+ out = []
806
+ with tempfile.TemporaryDirectory() as td:
807
+ for start in tqdm(range(0, len(lines), shard_size), desc=f"{desc} ({len(lines)} lines)", leave=False):
808
+ chunk = lines[start:start+shard_size]
809
+ src = os.path.join(td, f"src_{start}.txt")
810
+ dst = os.path.join(td, f"dst_{start}.bpe")
811
+ with open(src, "w", encoding="utf-8") as w:
812
+ for s in chunk: w.write((s or "").strip() + "\n")
813
+ subprocess.check_call([FASTBPE_BIN, "applybpe", dst, src, BPE_CODES])
814
+ with open(dst, "r", encoding="utf-8") as r:
815
+ for line in r:
816
+ out.append(line.strip().split())
817
+ return out
818
+
819
+ def tokens_to_ids(bpe_tokens):
820
+ ids = []
821
+ for t in bpe_tokens:
822
+ ids.append(token2id.get(t, pad_id))
823
+ return ids, sum(1 for t in bpe_tokens if t not in token2id) # second value = OOV count
824
+
825
+ def bpe_decode_tokens(bpe_tokens):
826
+ s = ' '.join(bpe_tokens).replace('@@ ', '')
827
+ return MosesDetokenizer(lang='en').detokenize(s.split())
828
+
829
+ # ---------- load HoC test ----------
830
+ def load_hoc_tsv(path):
831
+ df = pd.read_csv(path, sep="\t", header=None, dtype=str).fillna("")
832
+ assert df.shape[1] == 2, f"{path} must have 2 columns"
833
+ avg0, avg1 = df[0].astype(str).str.len().mean(), df[1].astype(str).str.len().mean()
834
+ df.columns = ["text","label"] if avg0 > avg1 else ["label","text"]
835
+ return df
836
+
837
+ test_path = os.path.join(HOC_DIR, "test.tsv")
838
+ assert os.path.exists(test_path), f"Missing {test_path}"
839
+ test_df = load_hoc_tsv(test_path)
840
+ print("Test size:", len(test_df))
841
+
842
+ # ---------- the 10 Hallmarks (no 'empty') ----------
843
+ HALLMARKS = [
844
+ "activating invasion and metastasis",
845
+ "avoiding immune destruction",
846
+ "cellular energetics",
847
+ "enabling replicative immortality",
848
+ "evading growth suppressors",
849
+ "genomic instability and mutation",
850
+ "inducing angiogenesis",
851
+ "resisting cell death",
852
+ "sustaining proliferative signaling",
853
+ "tumor promoting inflammation",
854
+ ]
855
+
856
+ def split_labels(s: str):
857
+ s = (s or "").strip()
858
+ if not s: return []
859
+ for sep in [",",";","|"]:
860
+ if sep in s:
861
+ return [p.strip() for p in s.split(sep) if p.strip()]
862
+ return [s]
863
+
864
+ def normalize_labels(labs):
865
+ keep, low = [], [L.lower() for L in HALLMARKS]
866
+ for x in labs:
867
+ xl = x.lower().strip()
868
+ if xl in low:
869
+ keep.append(HALLMARKS[low.index(xl)])
870
+ else:
871
+ best = difflib.get_close_matches(xl, low, n=1, cutoff=0.7)
872
+ if best:
873
+ keep.append(HALLMARKS[low.index(best[0])])
874
+ return sorted(dict.fromkeys(keep))
875
+
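+ # Example of the fuzzy label mapping (illustrative): a slightly mis-spelled label still
+ # resolves to its canonical hallmark via difflib.
+ print(normalize_labels(["inducing angiogenesis", "resisting cell deaths"]))
+ # expected: ['inducing angiogenesis', 'resisting cell death']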
876
+ # ---------- Build allowed-token mask (labels + separators + </s>) & first-step forbids ----------
877
+ def build_allowed_mask_and_first_forbid(vocab_size, device):
878
+ allowed = set()
879
+ sep_ids = set()
880
+ # Hallmark tokens (all tokens that appear in these strings)
881
+ for bpe in bpe_encode_lines(HALLMARKS, desc="BPE hallmarks"):
882
+ ids, _ = tokens_to_ids(bpe); allowed.update(ids)
883
+ # Separators; we also record their token ids to block at the first step
884
+ SEPS = [", ", ",", "; ", ";", "|", " and "]
885
+ for sep in SEPS:
886
+ bpe = bpe_encode_lines([sep], desc="BPE seps")[0]
887
+ ids, _ = tokens_to_ids(bpe)
888
+ allowed.update(ids)
889
+ sep_ids.update(ids)
890
+ allowed.add(eos_id)
891
+
892
+ mask = torch.full((vocab_size,), float('-inf'), device=device)
893
+ mask[list(allowed)] = 0.0
894
+ first_forbid = torch.zeros((vocab_size,), dtype=torch.bool, device=device)
895
+ first_forbid[list(sep_ids)] = True
896
+ first_forbid[eos_id] = True # never allow EOS as the first generated token
897
+ return mask, first_forbid
898
+
899
+ device = "cuda" if torch.cuda.is_available() else "cpu"
900
+ ALLOWED_MASK, FIRST_STEP_FORBID = build_allowed_mask_and_first_forbid(len(token2id), device)
901
+
902
+ # ---------- Build contexts (text </s> + textual cue) ----------
903
+ PROMPT_TEXT = " hallmarks of cancer:" # small cue after abstract
904
+ PROMPT_BPE = bpe_encode_lines([PROMPT_TEXT], desc="BPE prompt")[0]
905
+ PROMPT_IDS, _ = tokens_to_ids(PROMPT_BPE)
906
+
907
+ def make_context_with_prompt(df):
908
+ texts = df["text"].astype(str).tolist()
909
+ bpes = bpe_encode_lines(texts, desc="BPE test ctx")
910
+ ctx = []
911
+ for bpe in bpes:
912
+ ids, _ = tokens_to_ids(bpe)
913
+ ctx.append(np.array(ids + [eos_id] + PROMPT_IDS, dtype=np.int64))
914
+ return ctx
915
+
916
+ def pad_batch(seqs):
917
+ L = max(len(s) for s in seqs)
918
+ out = np.full((len(seqs), L), pad_id, dtype=np.int64)
919
+ for i, s in enumerate(seqs):
920
+ out[i, :len(s)] = s
921
+ return torch.from_numpy(out)
922
+
923
+ def ids_to_tokens(ids):
924
+ return [id2token.get(int(i), "<unk>") for i in ids]
925
+
926
+ def to_canonical(pred_chunk: str):
927
+ s = (pred_chunk or "").strip().lower()
928
+ low = [L.lower() for L in HALLMARKS]
929
+ if s in low: return HALLMARKS[low.index(s)]
930
+ best = difflib.get_close_matches(s, low, n=1, cutoff=0.7)
931
+ return HALLMARKS[low.index(best[0])] if best else None
932
+
933
+ # ---------- Require your GPT & GPTConfig from pretraining ----------
934
+ assert 'GPT' in globals(), "Please define your GPT class (same as pretraining) before running this cell."
935
+ assert 'GPTConfig' in globals(), "Please ensure GPTConfig is defined."
936
+
937
+ cfg = GPTConfig(
938
+ vocab_size=len(token2id),
939
+ block_size=(config.block_size if 'config' in globals() else 128),
940
+ n_layer=(config.n_layer if 'config' in globals() else 6),
941
+ n_head=(config.n_head if 'config' in globals() else 6),
942
+ n_embd=(config.n_embd if 'config' in globals() else 384),
943
+ dropout=(config.dropout if 'config' in globals() else 0.1),
944
+ bias=(config.bias if 'config' in globals() else True),
945
+ )
946
+ base = GPT(cfg).to(device)
947
+
948
+ # safe WPE resize when loading the checkpoint
949
+ def load_with_wpe_resize(model, ckpt_path):
950
+ sd = torch.load(ckpt_path, map_location="cpu")
951
+ key = "transformer.wpe.weight"
952
+ if key in sd:
953
+ old = sd[key]
954
+ new_w = model.transformer.wpe.weight
955
+ new_len = new_w.shape[0]
956
+ if old.shape[0] != new_len:
957
+ new = new_w.data.clone()
958
+ n = min(new_len, old.shape[0])
959
+ new[:n] = old[:n]
960
+ if new_len > n:
961
+ torch.nn.init.normal_(new[n:], mean=0.0, std=0.02)
962
+ sd[key] = new
963
+ missing, unexpected = model.load_state_dict(sd, strict=False)
964
+ if missing or unexpected:
965
+ print("Missing keys:", missing)
966
+ print("Loaded PRETRAINED checkpoint:", ckpt_path)
967
+
968
+ assert os.path.exists(ckpt_path), f"Checkpoint not found: {ckpt_path}"
969
+ load_with_wpe_resize(base, ckpt_path)
970
+ base.eval()
971
+
972
+ # ---------- Constrained greedy decode with cue + EOS delay ----------
973
+ @torch.no_grad()
974
+ def gpt_generate_with_cue(model, idx, allowed_mask, first_step_forbid,
975
+ max_new_tokens=24, min_new_before_eos=2, eos_penalty=-2.0, temperature=0.0):
976
+ """
977
+ - Restrict vocabulary with `allowed_mask`
978
+ - For the very first generated token, forbid separators + EOS
979
+ - For the first `min_new_before_eos` tokens, disallow EOS entirely
980
+ - After that, add a small penalty to EOS (so it doesn't end too early)
981
+ """
982
+ out = idx.clone()
983
+ B = out.size(0)
984
+ finished = torch.zeros(B, dtype=torch.bool, device=out.device)
985
+ steps = 0
986
+ for _ in range(max_new_tokens):
987
+ ctx = out[:, -model.config.block_size:]
988
+ logits, _ = model(ctx) # (B,1,V)
989
+ logits = logits[:, -1, :] # (B,V)
990
+
991
+ # restrict to label vocab
992
+ logits = logits + allowed_mask
993
+
994
+ # first token: block separators + EOS
995
+ if steps == 0:
996
+ logits[:, first_step_forbid] = -1e9
997
+
998
+ # delay EOS for a couple steps, then mildly penalize
999
+ if steps < min_new_before_eos:
1000
+ logits[:, eos_id] = -1e9
1001
+ else:
1002
+ logits[:, eos_id] += eos_penalty
1003
+
1004
+ # pick next
1005
+ if temperature <= 0:
1006
+ next_id = torch.argmax(logits, dim=-1)
1007
+ else:
1008
+ probs = F.softmax(logits / temperature, dim=-1)
1009
+ next_id = torch.multinomial(probs, num_samples=1).squeeze(1)
1010
+
1011
+ next_id = next_id.masked_fill(finished, eos_id)
1012
+ out = torch.cat([out, next_id.unsqueeze(1)], dim=1)
1013
+ finished |= (next_id == eos_id)
1014
+ steps += 1
1015
+ if bool(finished.all()):
1016
+ break
1017
+ return out[:, idx.size(1):]
1018
+
1019
+ @torch.no_grad()
1020
+ def predict_labels_for_batch_generative(xb):
1021
+ gens = gpt_generate_with_cue(
1022
+ base, xb, allowed_mask=ALLOWED_MASK, first_step_forbid=FIRST_STEP_FORBID,
1023
+ max_new_tokens=24, min_new_before_eos=2, eos_penalty=-2.0, temperature=0.0
1024
+ )
1025
+ preds = []
1026
+ for g in gens:
1027
+ toks = ids_to_tokens(g.detach().cpu().numpy())
1028
+ toks = toks[: toks.index("</s>")] if "</s>" in toks else toks
1029
+ label_str = bpe_decode_tokens(toks).strip().lower()
1030
+
1031
+ parts = []
1032
+ for sep in [",",";","|"]:
1033
+ if sep in label_str:
1034
+ parts = [p.strip() for p in label_str.split(sep) if p.strip()]
1035
+ break
1036
+ if not parts:
1037
+ parts = [label_str] if label_str else []
1038
+
1039
+ mapped = []
1040
+ for p in parts:
1041
+ can = to_canonical(p)
1042
+ if can and can not in mapped:
1043
+ mapped.append(can)
1044
+ preds.append(mapped) # may be []
1045
+ return preds
1046
+
1047
+ # ---------- Run decoding on TEST ----------
1048
+ ctx_test = make_context_with_prompt(test_df)
1049
+ preds_all = []
1050
+ B = 32
1051
+ for i in tqdm(range(0, len(ctx_test), B), desc="Decoding (pretrain+cue, test)"):
1052
+ xb = pad_batch(ctx_test[i:i+B]).to(device)
1053
+ preds_all.extend(predict_labels_for_batch_generative(xb))
1054
+
1055
+ # ---------- Ground truth & metrics (10 hallmarks only) ----------
1056
+ y_true = [ normalize_labels(split_labels(s)) for s in test_df["label"].astype(str).tolist() ]
1057
+ LABELS = HALLMARKS
1058
+ LIDX = {l:i for i,l in enumerate(LABELS)}
1059
+ def binarize(labs):
1060
+ v = [0]*len(LABELS)
1061
+ for l in labs:
1062
+ if l in LIDX: v[LIDX[l]] = 1
1063
+ return v
1064
+
1065
+ Y_true = np.array([binarize(l) for l in y_true], dtype=np.int64)
1066
+ Y_pred = np.array([binarize(l) for l in preds_all], dtype=np.int64)
1067
+
1068
+ micro_p, micro_r, micro_f1, _ = precision_recall_fscore_support(Y_true, Y_pred, average='micro', zero_division=0)
1069
+ macro_p, macro_r, macro_f1, _ = precision_recall_fscore_support(Y_true, Y_pred, average='macro', zero_division=0)
1070
+
1071
+ print(f"\n[PRETRAIN+cue] HALLMARKS-ONLY Micro P/R/F1: {micro_p:.4f} / {micro_r:.4f} / {micro_f1:.4f}")
1072
+ print( f"[PRETRAIN+cue] HALLMARKS-ONLY Macro P/R/F1: {macro_p:.4f} / {macro_r:.4f} / {macro_f1:.4f}")
1073
+
1074
+ perclass = precision_recall_fscore_support(Y_true, Y_pred, average=None, zero_division=0)
1075
+ per_df_pre = pd.DataFrame({
1076
+ "label": LABELS,
1077
+ "precision": perclass[0],
1078
+ "recall": perclass[1],
1079
+ "f1": perclass[2],
1080
+ "support": perclass[3],
1081
+ }).sort_values("label")
1082
+
1083
+ print("\nPer-class results (PRETRAIN+cue, 10 hallmarks):")
1084
+ print(per_df_pre.to_string(index=False))
1085
+
1086
+ per_df_pre.to_csv("hoc_test_results_pretrain_cue.csv", index=False)
1087
+ print("Saved: hoc_test_results_pretrain_cue.csv")
1088
+
1089
+ # (optional) exclude empty-label rows from eval:
1090
+ # mask = (Y_true.sum(axis=1) > 0)
1091
+ # ... recompute scores on Y_true[mask], Y_pred[mask]
1092
+
1093
+ """### 1.15 Evaluation on HoC Part 2 (the Hallmarks of Cancer corpus) classification dataset"""
1094
+
1095
+ # === Show 10 "questions" (abstract + prompt) and the model's answers (pretrained+cue) ===
1096
+ import os, difflib, numpy as np, pandas as pd, torch, torch.nn.functional as F
1097
+ from tqdm.auto import tqdm
1098
+ from sklearn.metrics import precision_recall_fscore_support
1099
+
1100
+ # ---- Assumptions / fallbacks ----
1101
+ HOC_DIR = globals().get("HOC_DIR", "/content/hoc")
1102
+ ckpt_path = globals().get("ckpt_path", "best_model_params.pt")
1103
+ device = "cuda" if torch.cuda.is_available() else "cpu"
1104
+
1105
+ # Hallmarks (10 classes, no "empty")
1106
+ HALLMARKS = [
1107
+ "activating invasion and metastasis",
1108
+ "avoiding immune destruction",
1109
+ "cellular energetics",
1110
+ "enabling replicative immortality",
1111
+ "evading growth suppressors",
1112
+ "genomic instability and mutation",
1113
+ "inducing angiogenesis",
1114
+ "resisting cell death",
1115
+ "sustaining proliferative signaling",
1116
+ "tumor promoting inflammation",
1117
+ ]
1118
+
1119
+ # ---------- Helper fallbacks if not defined earlier ----------
1120
+ def _need(name): return name not in globals()
1121
+
1122
+ # TSV loader
1123
+ if _need("load_hoc_tsv"):
1124
+ def load_hoc_tsv(path):
1125
+ df = pd.read_csv(path, sep="\t", header=None, dtype=str).fillna("")
1126
+ assert df.shape[1] == 2, f"{path} must have 2 columns"
1127
+ avg0, avg1 = df[0].astype(str).str.len().mean(), df[1].astype(str).str.len().mean()
1128
+ df.columns = ["text","label"] if avg0 > avg1 else ["label","text"]
1129
+ return df
1130
+
1131
+ # If test_df not in memory, load it
1132
+ if "test_df" not in globals():
1133
+ test_df = load_hoc_tsv(os.path.join(HOC_DIR, "test.tsv"))
1134
+
1135
+ # Simple label split/normalization utilities
1136
+ def split_labels(s: str):
1137
+ s = (s or "").strip()
1138
+ if not s: return []
1139
+ for sep in [",",";","|"]:
1140
+ if sep in s:
1141
+ return [p.strip() for p in s.split(sep) if p.strip()]
1142
+ return [s]
1143
+
1144
+ def normalize_labels(labs):
1145
+ keep, low = [], [L.lower() for L in HALLMARKS]
1146
+ for x in labs:
1147
+ xl = x.lower().strip()
1148
+ if xl in low:
1149
+ keep.append(HALLMARKS[low.index(xl)])
1150
+ else:
1151
+ best = difflib.get_close_matches(xl, low, n=1, cutoff=0.7)
1152
+ if best:
1153
+ keep.append(HALLMARKS[low.index(best[0])])
1154
+ # de-dup & stable order
1155
+ seen, out = set(), []
1156
+ for k in keep:
1157
+ if k not in seen:
1158
+ seen.add(k); out.append(k)
1159
+ return out
1160
+
1161
+ # BPE helpers (must exist: token2id, id2token, bpe_encode_lines, tokens_to_ids, bpe_decode_tokens, eos_id, pad_id)
1162
+ for req in ["token2id","id2token","bpe_encode_lines","tokens_to_ids","bpe_decode_tokens","eos_id","pad_id"]:
1163
+ assert req in globals(), f"Missing `{req}` — run the setup cell that defines dict/bpecodes and BPE helpers."
1164
+
1165
+ # Build allowed-token mask & first-step forbids if not present
1166
+ if _need("ALLOWED_MASK") or _need("FIRST_STEP_FORBID"):
1167
+ def build_allowed_mask_and_first_forbid(vocab_size, device):
1168
+ allowed = set(); sep_ids = set()
1169
+ # all tokens that appear in hallmark strings
1170
+ for bpe in bpe_encode_lines(HALLMARKS, desc="BPE hallmarks"):
1171
+ ids, _ = tokens_to_ids(bpe); allowed.update(ids)
1172
+ # separators (also block them on very first generated step)
1173
+ SEPS = [", ", ",", "; ", ";", "|", " and "]
1174
+ for sep in SEPS:
1175
+ bpe = bpe_encode_lines([sep], desc="BPE seps")[0]
1176
+ ids, _ = tokens_to_ids(bpe); allowed.update(ids); sep_ids.update(ids)
1177
+ allowed.add(eos_id)
1178
+ mask = torch.full((vocab_size,), float('-inf'), device=device)
1179
+ mask[list(allowed)] = 0.0
1180
+ first_forbid = torch.zeros((vocab_size,), dtype=torch.bool, device=device)
1181
+ first_forbid[list(sep_ids)] = True
1182
+ first_forbid[eos_id] = True
1183
+ return mask, first_forbid
1184
+ ALLOWED_MASK, FIRST_STEP_FORBID = build_allowed_mask_and_first_forbid(len(token2id), device)
1185
+
1186
+ # Prompt (the "question" cue)
1187
+ PROMPT_TEXT = " hallmarks of cancer:"
1188
+ PROMPT_BPE = bpe_encode_lines([PROMPT_TEXT], desc="BPE prompt")[0]
1189
+ PROMPT_IDS, _ = tokens_to_ids(PROMPT_BPE)
1190
+
1191
+ # Build contexts with prompt
1192
+ def make_context_with_prompt(rows):
1193
+ bpes = bpe_encode_lines(rows["text"].astype(str).tolist(), desc="BPE ctx (sample)")
1194
+ ctx = []
1195
+ for bpe in bpes:
1196
+ ids, _ = tokens_to_ids(bpe)
1197
+ ctx.append(np.array(ids + [eos_id] + PROMPT_IDS, dtype=np.int64))
1198
+ return ctx
1199
+
1200
+ def pad_batch(seqs):
1201
+ L = max(len(s) for s in seqs)
1202
+ out = np.full((len(seqs), L), pad_id, dtype=np.int64)
1203
+ for i, s in enumerate(seqs):
1204
+ out[i, :len(s)] = s
1205
+ return torch.from_numpy(out)
1206
+
1207
+ def ids_to_tokens(ids):
1208
+ return [id2token.get(int(i), "<unk>") for i in ids]
1209
+
1210
+ def to_canonical(pred_chunk: str):
1211
+ s = (pred_chunk or "").strip().lower()
1212
+ low = [L.lower() for L in HALLMARKS]
1213
+ if s in low: return HALLMARKS[low.index(s)]
1214
+ best = difflib.get_close_matches(s, low, n=1, cutoff=0.7)
1215
+ return HALLMARKS[low.index(best[0])] if best else None
1216
+
1217
+ # If the pretrained model (`base`) isn’t loaded yet, load it
1218
+ if _need("base"):
1219
+ assert 'GPT' in globals() and 'GPTConfig' in globals(), "Define GPT and GPTConfig first (your pretraining classes)."
1220
+ assert os.path.exists(ckpt_path), f"Checkpoint not found: {ckpt_path}"
1221
+ cfg = GPTConfig(
1222
+ vocab_size=len(token2id),
1223
+ block_size=(config.block_size if 'config' in globals() else 128),
1224
+ n_layer=(config.n_layer if 'config' in globals() else 6),
1225
+ n_head=(config.n_head if 'config' in globals() else 6),
1226
+ n_embd=(config.n_embd if 'config' in globals() else 384),
1227
+ dropout=(config.dropout if 'config' in globals() else 0.1),
1228
+ bias=(config.bias if 'config' in globals() else True),
1229
+ )
1230
+ base = GPT(cfg).to(device)
1231
+ # safe WPE resize
1232
+ def load_with_wpe_resize(model, path):
1233
+ sd = torch.load(path, map_location="cpu")
1234
+ key = "transformer.wpe.weight"
1235
+ if key in sd:
1236
+ old = sd[key]
1237
+ new_w = model.transformer.wpe.weight
1238
+ new_len = new_w.shape[0]
1239
+ if old.shape[0] != new_len:
1240
+ new = new_w.data.clone()
1241
+ n = min(new_len, old.shape[0])
1242
+ new[:n] = old[:n]
1243
+ if new_len > n:
1244
+ torch.nn.init.normal_(new[n:], mean=0.0, std=0.02)
1245
+ sd[key] = new
1246
+ model.load_state_dict(sd, strict=False)
1247
+ load_with_wpe_resize(base, ckpt_path)
1248
+ base.eval()
1249
+
1250
+ # Constrained generation with cue + EOS delay (define if missing)
1251
+ if _need("gpt_generate_with_cue"):
1252
+ @torch.no_grad()
1253
+ def gpt_generate_with_cue(model, idx, allowed_mask, first_step_forbid,
1254
+ max_new_tokens=24, min_new_before_eos=2, eos_penalty=-2.0, temperature=0.0):
1255
+ out = idx.clone()
1256
+ B = out.size(0)
1257
+ finished = torch.zeros(B, dtype=torch.bool, device=out.device)
1258
+ steps = 0
1259
+ for _ in range(max_new_tokens):
1260
+ ctx = out[:, -model.config.block_size:]
1261
+ logits, _ = model(ctx) # (B,1,V)
1262
+ logits = logits[:, -1, :] # (B,V)
1263
+ logits = logits + allowed_mask # restrict vocab
1264
+ if steps == 0:
1265
+ logits[:, first_step_forbid] = -1e9
1266
+ if steps < min_new_before_eos:
1267
+ logits[:, eos_id] = -1e9
1268
+ else:
1269
+ logits[:, eos_id] += eos_penalty
1270
+ if temperature <= 0:
1271
+ next_id = torch.argmax(logits, dim=-1)
1272
+ else:
1273
+ probs = F.softmax(logits / temperature, dim=-1)
1274
+ next_id = torch.multinomial(probs, num_samples=1).squeeze(1)
1275
+ next_id = next_id.masked_fill(finished, eos_id)
1276
+ out = torch.cat([out, next_id.unsqueeze(1)], dim=1)
1277
+ finished |= (next_id == eos_id)
1278
+ steps += 1
1279
+ if bool(finished.all()):
1280
+ break
1281
+ return out[:, idx.size(1):]
1282
+
1283
+ # ---------- Sample 10 and print Q&A ----------
1284
+ SAMPLE_N = 10
1285
+ sample = test_df.sample(n=min(SAMPLE_N, len(test_df)), random_state=42).reset_index(drop=True)
1286
+
1287
+ # prepare contexts
1288
+ ctx = make_context_with_prompt(sample)
1289
+ B = 10 # single batch is fine here
1290
+ xb = pad_batch(ctx).to(device)
1291
+
1292
+ # generate
1293
+ gens = gpt_generate_with_cue(
1294
+ base, xb, allowed_mask=ALLOWED_MASK, first_step_forbid=FIRST_STEP_FORBID,
1295
+ max_new_tokens=24, min_new_before_eos=2, eos_penalty=-2.0, temperature=0.0
1296
+ )
1297
+
1298
+ # decode + print
1299
+ for i, g in enumerate(gens):
1300
+ text = sample.loc[i, "text"]
1301
+ gold = normalize_labels(split_labels(sample.loc[i, "label"]))
1302
+
1303
+ toks = ids_to_tokens(g.detach().cpu().numpy())
1304
+ toks = toks[: toks.index("</s>")] if "</s>" in toks else toks
1305
+ raw = ' '.join(toks).replace('@@ ', '').strip().lower()
1306
+
1307
+ # split raw into parts and map to canonical labels
1308
+ parts = []
1309
+ for sep in [",",";","|"]:
1310
+ if sep in raw:
1311
+ parts = [p.strip() for p in raw.split(sep) if p.strip()]
1312
+ break
1313
+ if not parts:
1314
+ parts = [raw] if raw else []
1315
+ pred = []
1316
+ for p in parts:
1317
+ can = to_canonical(p)
1318
+ if can and can not in pred:
1319
+ pred.append(can)
1320
+
1321
+ print(f"\n=== Example {i+1} ===")
1322
+ print("QUESTION:")
1323
+ print("Abstract:", (text.replace("\n"," ")[:350] + ("..." if len(text) > 350 else "")))
1324
+ print("Prompt: hallmarks of cancer:")
1325
+ print("GOLD: ", gold if gold else "[]")
1326
+ print("ANSWER: ", pred if pred else "[]")
1327
+ print("Raw gen:", raw if raw else "<empty>")
1328
+
1329
+ """## Part 2: Finetuning
1330
+
1331
+ ### 2.1 Setup: paths + installs
1332
+ """
1333
+
1334
+ # Commented out IPython magic to ensure Python compatibility.
1335
+ # --- Setup: paths + installs (run once) ---
1336
+ !pip -q install sacremoses==0.0.53 scikit-learn==1.5.1
1337
+
1338
+ import os, subprocess, json, math, random, difflib, tempfile, shutil
1339
+ from pathlib import Path
1340
+ import numpy as np
1341
+ import pandas as pd
1342
+ from collections import Counter, defaultdict
1343
+
1344
+ import torch, torch.nn as nn, torch.nn.functional as F
1345
+ from torch.utils.data import Dataset, DataLoader
1346
+ from torch.optim.lr_scheduler import LinearLR, SequentialLR, CosineAnnealingLR
1347
+ from sacremoses import MosesDetokenizer
1348
+ from tqdm.auto import tqdm # <-- used in BPE w/ progress
1349
+
1350
+ # ---- paths ----
1351
+ HOC_DIR = "/content/hoc" # << put your train/valid/test.tsv here
1352
+ BPE_CODES = "/content/bpecodes" # from your pre-training cell
1353
+ DICT_TXT = "/content/dict.txt" # from your pre-training cell
1354
+ FASTBPE = "/content/fastBPE/fast" # compiled earlier in your notebook
1355
+
1356
+ os.makedirs(HOC_DIR, exist_ok=True)
1357
+
1358
+ # Ensure fastBPE exists (rebuild if needed)
1359
+ if not os.path.exists(FASTBPE):
1360
+ !git clone -q https://github.com/glample/fastBPE.git /content/fastBPE
1361
+ # %cd /content/fastBPE
1362
+ # use absolute paths so the build works even though the %cd magics above are commented out
+ !g++ -std=c++11 -O3 -pthread /content/fastBPE/fastBPE/main.cc -I/content/fastBPE/fastBPE -o /content/fastBPE/fast
1363
+ # %cd /content
1364
+
1365
+ # ---- load BioGPT dictionary ----
1366
+ token2id = {}
1367
+ id2token = {}
1368
+ with open(DICT_TXT, encoding="utf-8") as f:
1369
+ for i, line in enumerate(f):
1370
+ tok = line.split()[0]
1371
+ token2id[tok] = i
1372
+ id2token[i] = tok
1373
+
1374
+ # pick special ids
1375
+ eos_id = token2id.get("</s>", 0)
1376
+ pad_id = eos_id # safe padding with eos for inputs; we mask loss anyway
1377
+
1378
+ # ---- BPE encode/decode helpers (fastBPE uses '@@' continuation) ----
1379
+ def bpe_encode_lines(lines, shard_size=2000, desc="BPE"):
1380
+ """
1381
+ Progress-enabled BPE encoding using fastBPE, processing in shards.
1382
+ Returns: list[list[str]] (BPE tokens per line)
1383
+ """
1384
+ if len(lines) == 0:
1385
+ return []
1386
+ out_tokens = []
1387
+ with tempfile.TemporaryDirectory() as td:
1388
+ for start in tqdm(range(0, len(lines), shard_size), desc=f"{desc} ({len(lines)} lines)", leave=False):
1389
+ chunk = lines[start:start+shard_size]
1390
+ src = os.path.join(td, f"src_{start}.txt")
1391
+ dst = os.path.join(td, f"dst_{start}.bpe")
1392
+ with open(src, "w", encoding="utf-8") as f:
1393
+ for s in chunk:
1394
+ f.write((s or "").strip() + "\n")
1395
+ subprocess.check_call([FASTBPE, "applybpe", dst, src, BPE_CODES])
1396
+ with open(dst, "r", encoding="utf-8") as f:
1397
+ for line in f:
1398
+ out_tokens.append(line.strip().split())
1399
+ return out_tokens
1400
+
1401
+ def bpe_decode_tokens(bpe_tokens):
1402
+ """Merge '@@' continuations and detokenize to plain text (for label decoding)."""
1403
+ s = ' '.join(bpe_tokens).replace('@@ ', '')
1404
+ md = MosesDetokenizer(lang='en')
1405
+ return md.detokenize(s.split())
1406
+
1407
+ def tokens_to_ids(bpe_tokens):
1408
+ ids = []
1409
+ oov = 0
1410
+ for t in bpe_tokens:
1411
+ if t in token2id:
1412
+ ids.append(token2id[t])
1413
+ else:
1414
+ ids.append(pad_id) # unlikely, but safe fallback
1415
+ oov += 1
1416
+ return ids, oov
1417
+
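+ # Optional sanity check (illustrative only): round-trip a toy sentence through the
+ # helpers above. Assumes BPE_CODES, DICT_TXT and the fastBPE binary are already in place.
+ _demo_bpe = bpe_encode_lines(["Tumor cells evade apoptosis."], desc="BPE demo")[0]
+ _demo_ids, _demo_oov = tokens_to_ids(_demo_bpe)
+ print("BPE tokens:", _demo_bpe, "| ids:", _demo_ids, "| OOV:", _demo_oov)
+ print("Decoded back:", bpe_decode_tokens(_demo_bpe))
+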
1418
+ """### 2.2 Load HoC dataset and map targets to labels"""
1419
+
1420
+ # --- Load HoC TSVs (2 columns, no header). Heuristically figure out which is text vs label. ---
1421
+ def load_hoc_tsv(path):
1422
+ df = pd.read_csv(path, sep="\t", header=None, dtype=str).fillna("")
1423
+ assert df.shape[1] == 2, f"Expected 2 columns in {path}, got {df.shape}"
1424
+ avg0, avg1 = df[0].astype(str).str.len().mean(), df[1].astype(str).str.len().mean()
1425
+ if avg0 > avg1:
1426
+ df.columns = ["text", "label"]
1427
+ else:
1428
+ df.columns = ["label", "text"]
1429
+ return df
1430
+
1431
+ train_df = load_hoc_tsv(f"{HOC_DIR}/train.tsv")
1432
+ valid_df = load_hoc_tsv(f"{HOC_DIR}/valid.tsv")
1433
+ test_df = load_hoc_tsv(f"{HOC_DIR}/test.tsv")
1434
+
1435
+ print("Splits:", len(train_df), len(valid_df), len(test_df))
1436
+
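+ # Optional peek (illustrative): confirm the text/label column heuristic worked and see
+ # which raw label strings dominate the training split.
+ print(train_df.head(2).to_string())
+ print("Most common raw labels:", Counter(train_df["label"]).most_common(3))
+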
1437
+ # --- Hallmarks (10 classes; we ignore 'empty' for training and for reporting) ---
1438
+ HALLMARKS = [
1439
+ "activating invasion and metastasis",
1440
+ "avoiding immune destruction",
1441
+ "cellular energetics",
1442
+ "enabling replicative immortality",
1443
+ "evading growth suppressors",
1444
+ "genomic instability and mutation",
1445
+ "inducing angiogenesis",
1446
+ "resisting cell death",
1447
+ "sustaining proliferative signaling",
1448
+ "tumor promoting inflammation",
1449
+ ]
1450
+
1451
+ def split_labels(s: str):
1452
+ s = (s or "").strip()
1453
+ if not s: return []
1454
+ for sep in [",", ";", "|"]:
1455
+ if sep in s:
1456
+ return [p.strip() for p in s.split(sep) if p.strip()]
1457
+ return [s]
1458
+
1459
+ def normalize_labels(labs):
1460
+ """Map raw labels (including fuzzy matches) to the 10 hallmarks; drop 'empty'."""
1461
+ keep = []
1462
+ low = [L.lower() for L in HALLMARKS]
1463
+ for x in labs:
1464
+ x_low = x.lower().strip()
1465
+ if x_low in low:
1466
+ keep.append(HALLMARKS[low.index(x_low)])
1467
+ else:
1468
+ best = difflib.get_close_matches(x_low, low, n=1, cutoff=0.7)
1469
+ if best:
1470
+ keep.append(HALLMARKS[low.index(best[0])])
1471
+ # dedupe & sort for deterministic target text
1472
+ return sorted(list(dict.fromkeys(keep)))
1473
+
1474
+ def labels_to_target_text(labs):
1475
+ labs = normalize_labels(labs)
1476
+ if len(labs) == 0:
1477
+ return None # -> drop from training if empty-only
1478
+ return ", ".join(labs)
1479
+
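+ # Quick illustration (not used below): normalization should be robust to casing and
+ # near-miss spellings, sort deterministically, and turn empty-only rows into None.
+ print(labels_to_target_text(["Resisting Cell Death", "inducing angiogenesis"]))
+ # expected: "inducing angiogenesis, resisting cell death"
+ print(labels_to_target_text(["empty"]))  # expected: None (row would be dropped from training)
+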
1480
+ """### 2.3 Redefine GPT architecture for full finetuning"""
1481
+
1482
+ # --- Your GPT modules (same as in your pretraining code) ---
1483
+ class LayerNorm(nn.Module):
1484
+ def __init__(self, ndim, bias):
1485
+ super().__init__()
1486
+ self.weight = nn.Parameter(torch.ones(ndim))
1487
+ self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
1488
+ def forward(self, x):
1489
+ return F.layer_norm(x, self.weight.shape, self.weight, self.bias, 1e-5)
1490
+
1491
+ class CausalSelfAttention(nn.Module):
1492
+ def __init__(self, config):
1493
+ super().__init__()
1494
+ assert config.n_embd % config.n_head == 0
1495
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
1496
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
1497
+ self.attn_dropout = nn.Dropout(config.dropout)
1498
+ self.resid_dropout = nn.Dropout(config.dropout)
1499
+ self.n_head = config.n_head
1500
+ self.n_embd = config.n_embd
1501
+ self.flash = hasattr(F, 'scaled_dot_product_attention')
1502
+ if not self.flash:
1503
+ self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
1504
+ .view(1, 1, config.block_size, config.block_size))
1505
+ def forward(self, x):
1506
+ B, T, C = x.size()
1507
+ q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
1508
+ k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
1509
+ q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
1510
+ v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
1511
+ if self.flash:
1512
+ y = F.scaled_dot_product_attention(q, k, v, attn_mask=None,
1513
+ dropout_p=self.attn_dropout.p if self.training else 0.0,
1514
+ is_causal=True)
1515
+ else:
1516
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
1517
+ att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
1518
+ att = F.softmax(att, dim=-1)
1519
+ att = self.attn_dropout(att)
1520
+ y = att @ v
1521
+ y = y.transpose(1, 2).contiguous().view(B, T, C)
1522
+ y = self.resid_dropout(self.c_proj(y))
1523
+ return y
1524
+
1525
+ class MLP(nn.Module):
1526
+ def __init__(self, config):
1527
+ super().__init__()
1528
+ self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
1529
+ self.gelu = nn.GELU()
1530
+ self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
1531
+ self.dropout = nn.Dropout(config.dropout)
1532
+ def forward(self, x):
1533
+ return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))
1534
+
1535
+ class Block(nn.Module):
1536
+ def __init__(self, config):
1537
+ super().__init__()
1538
+ self.ln1 = LayerNorm(config.n_embd, config.bias)
1539
+ self.attn = CausalSelfAttention(config)
1540
+ self.ln2 = LayerNorm(config.n_embd, config.bias)
1541
+ self.mlp = MLP(config)
1542
+ def forward(self, x):
1543
+ x = x + self.attn(self.ln1(x))
1544
+ x = x + self.mlp(self.ln2(x))
1545
+ return x
1546
+
1547
+ from dataclasses import dataclass
1548
+ @dataclass
1549
+ class GPTConfig:
1550
+ block_size: int
1551
+ vocab_size: int
1552
+ n_layer: int
1553
+ n_head: int
1554
+ n_embd: int
1555
+ dropout: float = 0.0
1556
+ bias: bool = True
1557
+
1558
+ class GPT(nn.Module):
1559
+ def __init__(self, config):
1560
+ super().__init__()
1561
+ self.config = config
1562
+ self.transformer = nn.ModuleDict(dict(
1563
+ wte=nn.Embedding(config.vocab_size, config.n_embd),
1564
+ wpe=nn.Embedding(config.block_size, config.n_embd),
1565
+ drop=nn.Dropout(config.dropout),
1566
+ h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
1567
+ ln_f=LayerNorm(config.n_embd, config.bias),
1568
+ ))
1569
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
1570
+ # weight tying
1571
+ self.transformer.wte.weight = self.lm_head.weight
1572
+
1573
+ self.apply(self._init_weights)
1574
+ for pn, p in self.named_parameters():
1575
+ if pn.endswith('c_proj.weight'):
1576
+ nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))
1577
+
1578
+ def _init_weights(self, module):
1579
+ if isinstance(module, nn.Linear):
1580
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
1581
+ if module.bias is not None:
1582
+ nn.init.zeros_(module.bias)
1583
+ elif isinstance(module, nn.Embedding):
1584
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
1585
+
1586
+ def forward(self, idx, targets=None):
1587
+ device = idx.device
1588
+ B, T = idx.size()
1589
+ assert T <= self.config.block_size
1590
+ pos = torch.arange(0, T, dtype=torch.long, device=device)
1591
+ tok_emb = self.transformer.wte(idx)
1592
+ pos_emb = self.transformer.wpe(pos)
1593
+ x = self.transformer.drop(tok_emb + pos_emb)
1594
+ for block in self.transformer.h:
1595
+ x = block(x)
1596
+ x = self.transformer.ln_f(x)
1597
+ if targets is not None:
1598
+ logits = self.lm_head(x) # (B,T,V)
1599
+ loss = F.cross_entropy(
1600
+ logits.view(-1, logits.size(-1)),
1601
+ targets.view(-1),
1602
+ ignore_index=-1
1603
+ )
1604
+ return logits, loss
1605
+ else:
1606
+ logits = self.lm_head(x[:, [-1], :]) # (B,1,V)
1607
+ return logits, None
1608
+
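+ # Tiny smoke test (illustrative, throwaway config): exercises both forward modes of the
+ # GPT above -- full logits + loss when targets are given, last-position logits otherwise.
+ _cfg_demo = GPTConfig(block_size=16, vocab_size=100, n_layer=1, n_head=2, n_embd=32)
+ _gpt_demo = GPT(_cfg_demo)
+ _x_demo = torch.randint(0, 100, (2, 8))
+ _logits_demo, _loss_demo = _gpt_demo(_x_demo, targets=_x_demo)
+ print(_logits_demo.shape, float(_loss_demo))   # torch.Size([2, 8, 100]) and a finite loss
+ _last_demo, _ = _gpt_demo(_x_demo)
+ print(_last_demo.shape)                        # torch.Size([2, 1, 100])
+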
1609
+ """### 2.4 Define Add SoftPrompt embeddings to input embeddings"""
1610
+
1611
+ class GPTWithSoftPrompt(nn.Module):
1612
+ def __init__(self, base_gpt: GPT, prompt_len=1):
1613
+ super().__init__()
1614
+ self.config = base_gpt.config
1615
+ self.transformer = base_gpt.transformer
1616
+ self.lm_head = base_gpt.lm_head
1617
+ C = self.config.n_embd
1618
+ self.soft_prompt = nn.Parameter(torch.zeros(1, prompt_len, C))
1619
+ nn.init.normal_(self.soft_prompt, mean=0.0, std=0.02)
1620
+
1621
+ def forward(self, idx, targets=None):
1622
+ B, T = idx.shape
1623
+ device = idx.device
1624
+
1625
+ # token + pos
1626
+ tok_emb = self.transformer.wte(idx) # (B,T,C)
1627
+ pos = torch.arange(0, T, dtype=torch.long, device=device)
1628
+ pos_emb = self.transformer.wpe(pos) # (T,C)
1629
+ x_tokens = tok_emb + pos_emb
1630
+
1631
+ # prepend soft prompt
1632
+ soft = self.soft_prompt.expand(B, -1, -1) # (B,P,C)
1633
+ x = torch.cat([soft, x_tokens], dim=1) # (B,P+T,C)
1634
+
1635
+ x = self.transformer.drop(x)
1636
+ for block in self.transformer.h:
1637
+ x = block(x)
1638
+ x = self.transformer.ln_f(x)
1639
+ logits = self.lm_head(x) # (B,P+T,V)
1640
+
1641
+ if targets is None:
1642
+ # return next-token logits at last (standard for generation)
1643
+ return logits[:, -1, :], None
1644
+
1645
+ # ----- FIX: next-token loss with soft-prompt padding -----
1646
+ P = soft.size(1)
1647
+ pad_ignore = torch.full((B, P), -1, dtype=targets.dtype, device=device) # ignore for soft prompt
1648
+ full_targets = torch.cat([pad_ignore, targets], dim=1) # (B,P+T)
1649
+
1650
+ # shift for next-token prediction
1651
+ logits_lm = logits[:, :-1, :].contiguous() # predict next token
1652
+ targets_lm = full_targets[:, 1:].contiguous()
1653
+
1654
+ loss = F.cross_entropy(
1655
+ logits_lm.view(-1, logits_lm.size(-1)),
1656
+ targets_lm.view(-1),
1657
+ ignore_index=-1
1658
+ )
1659
+ return logits, loss
1660
+
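+ # Illustration (optional): with prompt_len=P the model sees P+T positions, the first P
+ # target slots are filled with -1 (ignored by cross_entropy), and the usual one-step
+ # shift aligns position t with target t+1.
+ _sp_demo = GPTWithSoftPrompt(
+     GPT(GPTConfig(block_size=16, vocab_size=100, n_layer=1, n_head=2, n_embd=32)),
+     prompt_len=2,
+ )
+ _x_sp = torch.randint(0, 100, (2, 8))
+ _logits_sp, _loss_sp = _sp_demo(_x_sp, targets=_x_sp)
+ print(_logits_sp.shape, float(_loss_sp))   # torch.Size([2, 10, 100]); loss over shifted targets
+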
1661
+ """### 2.5 Instantiate pre-training weights"""
1662
+
1663
+ # --- Instantiate & (optionally) load your pretraining weights ---
1664
+ device = "cuda" if torch.cuda.is_available() else "cpu"
1665
+
1666
+ # Use your pretrain block_size (128 in your earlier run). If different, the loader below can resize wpe.
1667
+ BLOCK_SIZE = 128 # set to 128 if that was your pretrain; otherwise set to your pretrain context length
1668
+
1669
+ config = GPTConfig(
1670
+ vocab_size=len(token2id),
1671
+ block_size=BLOCK_SIZE,
1672
+ n_layer=6, n_head=6, n_embd=384,
1673
+ dropout=0.1, bias=True
1674
+ )
1675
+ base_gpt = GPT(config)
1676
+
1677
+ def load_with_wpe_resize(model, ckpt_path):
1678
+ sd = torch.load(ckpt_path, map_location="cpu")
1679
+ key = "transformer.wpe.weight"
1680
+ if key in sd:
1681
+ old = sd[key]
1682
+ new_len = model.transformer.wpe.weight.shape[0]
1683
+ if old.shape[0] != new_len:
1684
+ # copy existing, init the rest
1685
+ new = model.transformer.wpe.weight.data.clone()
1686
+ n = min(new_len, old.shape[0])
1687
+ new[:n] = old[:n]
1688
+ if new_len > n:
1689
+ nn.init.normal_(new[n:], mean=0.0, std=0.02)
1690
+ sd[key] = new
1691
+ missing, unexpected = model.load_state_dict(sd, strict=False)
1692
+ print("Loaded state dict with resize. Missing:", missing, "Unexpected:", unexpected)
1693
+
1694
+ pt_path = "best_model_params.pt"
1695
+ if os.path.exists(pt_path):
1696
+ load_with_wpe_resize(base_gpt, pt_path)
1697
+ print("Loaded pretraining weights from:", pt_path)
1698
+ else:
1699
+ print("No pretrain checkpoint found; training soft prompt from scratch on top of random GPT.")
1700
+
1701
+ model = GPTWithSoftPrompt(base_gpt, prompt_len=1).to(device)
1702
+
1703
+ """### 2.6 Build a mask of token IDs that are allowed during generation"""
1704
+
1705
+ # --- Constrained token mask (only hallmarks + separators + </s>) ---
1706
+ def build_allowed_token_mask(vocab_size, device):
1707
+ allowed = set()
1708
+ # hallmark token ids
1709
+ for bpe in bpe_encode_lines(HALLMARKS, desc="BPE hallmarks"):
1710
+ ids, _ = tokens_to_ids(bpe)
1711
+ allowed.update(ids)
1712
+ # separators
1713
+ for sep in [", ", ",", "; ", ";", "|", " and "]:
1714
+ bpe = bpe_encode_lines([sep], desc="BPE seps")[0]
1715
+ ids, _ = tokens_to_ids(bpe)
1716
+ allowed.update(ids)
1717
+ allowed.add(eos_id)
1718
+ mask = torch.full((vocab_size,), float('-inf'), device=device)
1719
+ mask[list(allowed)] = 0.0
1720
+ return mask
1721
+
1722
+ ALLOWED_MASK = build_allowed_token_mask(len(token2id), device)
1723
+
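+ # Sanity check (illustrative): the allowed set should be a small fraction of the vocab --
+ # roughly the BPE pieces of the 10 hallmark names, a few separators, and </s>. Everything
+ # else is set to -inf and can never be generated.
+ print("Allowed tokens:", int((ALLOWED_MASK == 0).sum().item()), "of", ALLOWED_MASK.numel())
+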
1724
+ """### 2.7:
1725
+
1726
+ - Define a dataset class that encodes abstracts and labels into token IDs (dropping empty-only rows for training if desired)
1727
+ - Concatenate them into input/target sequences respecting a block size
1728
+ - Provide a collate function to pad batches for training.
1729
+ """
1730
+
1731
+ # --- Dataset (drops empty-only rows for TRAIN to avoid collapse) ---
1732
+ class HoCGenDataset(Dataset):
1733
+ def __init__(self, df, block_size=256, drop_empty_only=False, name=""):
1734
+ self.block_size = block_size
1735
+ self.samples = []
1736
+
1737
+ texts = df["text"].astype(str).tolist()
1738
+ raw_labels = [split_labels(s) for s in df["label"].astype(str).tolist()]
1739
+
1740
+ # BPE encode texts with progress
1741
+ text_bpe = bpe_encode_lines(texts, shard_size=2000, desc=f"BPE {name or 'dataset'}")
1742
+
1743
+ # Pre-encode unique label targets
1744
+ targets = []
1745
+ for labs in raw_labels:
1746
+ tgt = labels_to_target_text(labs) # None if empty-only
1747
+ targets.append(tgt)
1748
+ uniq_non_null = sorted(set([t for t in targets if t is not None]))
1749
+
1750
+ label_cache = {}
1751
+ if len(uniq_non_null) > 0:
1752
+ encoded = bpe_encode_lines(uniq_non_null, shard_size=200, desc=f"BPE labels {name or 'dataset'}")
1753
+ for s, bpe in zip(uniq_non_null, encoded):
1754
+ ids, _ = tokens_to_ids(bpe)
1755
+ label_cache[s] = ids
1756
+
1757
+ # Pack samples
1758
+ for bpe, tgt in tqdm(list(zip(text_bpe, targets)), total=len(text_bpe), desc=f"Packing {name or 'dataset'}", leave=False):
1759
+ if drop_empty_only and tgt is None:
1760
+ continue
1761
+ text_ids, _ = tokens_to_ids(bpe)
1762
+
1763
+ if tgt is None:
1764
+ label_ids = []
1765
+ else:
1766
+ label_ids = label_cache[tgt]
1767
+
1768
+ x_ids = text_ids + [eos_id]
1769
+ y_ids = (label_ids + [eos_id]) if len(label_ids) > 0 else []
1770
+
1771
+ # respect block size
1772
+ max_text = self.block_size - (2 if len(y_ids) > 0 else 1) - len(y_ids)
1773
+ if max_text < 1:
1774
+ x_ids = x_ids[:max(1, self.block_size // 2)]
1775
+ else:
1776
+ x_ids = x_ids[:max_text]
1777
+
1778
+ input_ids = x_ids + y_ids
1779
+ targets_arr = ([-1] * len(x_ids)) + (y_ids if len(y_ids) > 0 else [])
1780
+
1781
+ self.samples.append((
1782
+ np.array(input_ids, dtype=np.int64),
1783
+ np.array(targets_arr, dtype=np.int64)
1784
+ ))
1785
+
1786
+ def __len__(self): return len(self.samples)
1787
+ def __getitem__(self, idx): return self.samples[idx]
1788
+
1789
+ def collate(batch):
1790
+ L = max(len(x[0]) for x in batch)
1791
+ B = len(batch)
1792
+ inputs = np.full((B, L), pad_id, dtype=np.int64)
1793
+ targets = np.full((B, L), -1, dtype=np.int64)
1794
+ for i, (inp, tgt) in enumerate(batch):
1795
+ n = len(inp)
1796
+ inputs[i, :n] = inp
1797
+ targets[i, :n] = tgt
1798
+ return torch.from_numpy(inputs), torch.from_numpy(targets)
1799
+
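+ # Illustration (optional): collate right-pads inputs with pad_id and targets with -1, so
+ # padded positions -- like the abstract positions -- contribute nothing to the loss.
+ _toy_batch = [
+     (np.array([5, 6, 7], dtype=np.int64), np.array([-1, -1, 7], dtype=np.int64)),
+     (np.array([5, 6],    dtype=np.int64), np.array([-1, 6],     dtype=np.int64)),
+ ]
+ _xi, _yi = collate(_toy_batch)
+ print(_xi.shape, _yi.tolist())   # torch.Size([2, 3]); second row ends in pad_id / -1
+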
1800
+ """### 2.8 Create dataloaders for the finetuning dataset"""
1801
+
1802
+ # --- Datasets/Loaders ---
1803
+ BATCH_SIZE = 16
1804
+
1805
+ # Train: drop empty-only rows (crucial)
1806
+ train_ds = HoCGenDataset(train_df, block_size=model.config.block_size, drop_empty_only=True, name="train")
1807
+ # Valid: drop empty-only too (makes val loss meaningful)
1808
+ valid_ds = HoCGenDataset(valid_df, block_size=model.config.block_size, drop_empty_only=True, name="valid")
1809
+ # Test: keep all rows; we'll evaluate on the 10 hallmarks only later
1810
+ test_ds = HoCGenDataset(test_df, block_size=model.config.block_size, drop_empty_only=False, name="test")
1811
+
1812
+ cuda_gen = torch.Generator(device='cuda') # or set a manual seed if you want
1813
+
1814
+ train_loader = DataLoader(
1815
+ train_ds, batch_size=BATCH_SIZE, shuffle=True,
1816
+ collate_fn=collate, drop_last=True,
1817
+ generator=cuda_gen,  # avoids a generator/device mismatch when CUDA is the default device
1818
+ pin_memory=True, pin_memory_device='cuda'
1819
+ )
1820
+
1821
+ valid_loader = DataLoader(
1822
+ valid_ds, batch_size=BATCH_SIZE, shuffle=False,
1823
+ collate_fn=collate,
1824
+ generator=cuda_gen,
1825
+ pin_memory=True, pin_memory_device='cuda'
1826
+ )
1827
+
1828
+ test_loader = DataLoader(
1829
+ test_ds, batch_size=BATCH_SIZE, shuffle=False,
1830
+ collate_fn=collate,
1831
+ generator=cuda_gen,
1832
+ pin_memory=True, pin_memory_device='cuda'
1833
+ )
1834
+
1835
+ print(f"Train samples (non-empty only): {len(train_ds)}")
1836
+ print(f"Valid samples (non-empty only): {len(valid_ds)}")
1837
+ print(f"Test samples (incl. empty): {len(test_ds)}")
1838
+
1839
+ xb, yb = next(iter(train_loader))
1840
+ assert (yb != -1).any(), "No supervised label tokens in this batch — are we dropping all rows?"
1841
+
1842
+ xb, yb = xb.to(device), yb.to(device)
1843
+ with torch.no_grad():
1844
+ _, loss = model(xb, yb)
1845
+ print("Initial loss:", float(loss))
1846
+
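+ # Rough yardstick (assumption: uniform next-token prediction): an uninformed model would
+ # sit near ln(vocab_size); the initial loss above should already be below that if the
+ # pretraining checkpoint loaded correctly.
+ print("ln(vocab) reference:", round(math.log(len(token2id)), 3))
+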
1847
+ """### 2.9
1848
+
1849
+ - Feeds the current context into the model (self(ctx)).
1850
+
1851
+ - Adds the allowed_mask to the logits so that only permitted token IDs (Hallmarks, separators, </s>) can be chosen; all others get -inf and are impossible to sample.
1852
+
1853
+ - Picks the next token greedily (argmax) unless a temperature is set, in which case it samples.
1854
+
1855
+ - Forces already finished sequences to emit </s> and stops early when all sequences are finished.
1856
+ """
1857
+
1858
+ # --- Constrained, batched decoding method for GPTWithSoftPrompt ---
1859
+ def constrained_generate_labels(self, idx, allowed_mask, max_new_tokens=24, temperature=0.0):
1860
+ """
1861
+ Batched decode. At each step, mask logits to the allowed set.
1862
+ Returns only generated tail (B, Tgen).
1863
+ """
1864
+ self.eval()
1865
+ B = idx.size(0)
1866
+ out = idx.clone()
1867
+ finished = torch.zeros(B, dtype=torch.bool, device=idx.device)
1868
+
1869
+ for _ in range(max_new_tokens):
1870
+ ctx = out[:, -self.config.block_size:]
1871
+ logits, _ = self(ctx) # (B,V)
1872
+ # apply constraint
1873
+ logits = logits + allowed_mask
1874
+ if temperature <= 0:
1875
+ next_id = torch.argmax(logits, dim=-1) # (B,)
1876
+ else:
1877
+ probs = F.softmax(logits / temperature, dim=-1)
1878
+ next_id = torch.multinomial(probs, num_samples=1).squeeze(1)
1879
+
1880
+ next_id = next_id.masked_fill(finished, eos_id)
1881
+ out = torch.cat([out, next_id.unsqueeze(1)], dim=1)
1882
+ finished |= (next_id == eos_id)
1883
+ if bool(finished.all()):
1884
+ break
1885
+ return out[:, idx.size(1):]
1886
+
1887
+ # attach to instance/class
1888
+ GPTWithSoftPrompt.generate_labels = constrained_generate_labels
1889
+
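+ # Usage sketch (illustrative; relies on pad_batch/make_context_only, which are only
+ # defined in section 2.11, hence commented out here): decode a single abstract with the
+ # constrained sampler. Before finetuning this mostly emits arbitrary allowed tokens.
+ #   _ctx = pad_batch(make_context_only(test_df.head(1))).to(device)
+ #   _gen = model.generate_labels(_ctx, allowed_mask=ALLOWED_MASK, max_new_tokens=8)
+ #   print(_gen.shape)   # (1, <=8) generated token ids
+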
1890
+ """### 2.10 Run the finetuning loop"""
1891
+
1892
+ # --- Optimizer & schedulers (paper: 20k steps, warmup 1k, peak LR 1e-5) ---
1893
+ max_steps = 20_000
1894
+ warmup = 1_000
1895
+ peak_lr = 1e-5
1896
+ eta_min = 1e-6
1897
+
1898
+ optimizer = torch.optim.AdamW(model.parameters(), lr=peak_lr, betas=(0.9, 0.95), weight_decay=0.01, eps=1e-9)
1899
+ sched_warm = LinearLR(optimizer, total_iters=warmup)
1900
+ sched_decay = CosineAnnealingLR(optimizer, T_max=max_steps - warmup, eta_min=eta_min)
1901
+ scheduler = SequentialLR(optimizer, [sched_warm, sched_decay], milestones=[warmup])
1902
+
1903
+ # AMP dtype: bf16 if supported, else fp16; enable GradScaler only if fp16
1904
+ amp_dtype = torch.bfloat16 if (torch.cuda.is_available() and torch.cuda.is_bf16_supported()) else torch.float16
1905
+ scaler = torch.cuda.amp.GradScaler(enabled=(amp_dtype == torch.float16))
1906
+
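+ # Optional preview (illustrative): replay the warmup + cosine schedule on a throwaway
+ # optimizer so the LR curve can be eyeballed without touching the real one. Note that
+ # LinearLR's default start_factor is 1/3, so warmup begins at peak_lr/3 rather than ~0.
+ _p_probe = nn.Parameter(torch.zeros(1)); _p_probe.grad = torch.zeros(1)
+ _opt_probe = torch.optim.AdamW([_p_probe], lr=peak_lr)
+ _sched_probe = SequentialLR(
+     _opt_probe,
+     [LinearLR(_opt_probe, total_iters=warmup),
+      CosineAnnealingLR(_opt_probe, T_max=max_steps - warmup, eta_min=eta_min)],
+     milestones=[warmup],
+ )
+ for _step in range(max_steps):
+     if _step in (0, warmup - 1, warmup, max_steps // 2, max_steps - 1):
+         print(f"step {_step:>6}: lr {_opt_probe.param_groups[0]['lr']:.2e}")
+     _opt_probe.step(); _sched_probe.step()
+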
1907
+ def run_eval(loader):
1908
+ model.eval()
1909
+ losses = []
1910
+ with torch.no_grad():
1911
+ for xb, yb in tqdm(loader, desc="Valid", leave=False):
1912
+ xb, yb = xb.to(device), yb.to(device)
1913
+ with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=torch.cuda.is_available()):
1914
+ _, loss = model(xb, yb)
1915
+ losses.append(loss.item())
1916
+ model.train()
1917
+ return float(np.mean(losses)) if losses else 0.0
1918
+
1919
+ # --- Training loop ---
1920
+ EVAL_EVERY = 500
1921
+ BEST_PATH = "hoc_best.pt"
1922
+
1923
+ best_val = float('inf')
1924
+ global_step = 0
1925
+ ema_loss = None
1926
+ pbar = tqdm(total=max_steps, desc="Finetuning (HoC)", leave=True)
1927
+
1928
+ model.train()
1929
+ while global_step < max_steps:
1930
+ for xb, yb in train_loader:
1931
+ xb, yb = xb.to(device), yb.to(device)
1932
+ with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=torch.cuda.is_available()):
1933
+ _, loss = model(xb, yb)
1934
+
1935
+ scaler.scale(loss).backward()
+ scaler.unscale_(optimizer)  # unscale first so clipping sees the true gradient norms (no-op for bf16)
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
1937
+ scaler.step(optimizer)
1938
+ scaler.update()
1939
+ optimizer.zero_grad(set_to_none=True)
1940
+ scheduler.step()
1941
+
1942
+ global_step += 1
1943
+ pbar.update(1)
1944
+
1945
+ cur = loss.item()
1946
+ ema_loss = cur if ema_loss is None else (0.95 * ema_loss + 0.05 * cur)
1947
+ pbar.set_postfix({
1948
+ "train_loss": f"{cur:.3f}",
1949
+ "ema": f"{ema_loss:.3f}",
1950
+ "best_val": f"{best_val:.3f}" if best_val < float('inf') else "—",
1951
+ "lr": f"{optimizer.param_groups[0]['lr']:.2e}",
1952
+ })
1953
+
1954
+ if global_step % EVAL_EVERY == 0:
1955
+ val_loss = run_eval(valid_loader)
1956
+ if val_loss < best_val:
1957
+ best_val = val_loss
1958
+ torch.save(model.state_dict(), BEST_PATH)
1959
+ pbar.set_postfix({
1960
+ "train_loss": f"{cur:.3f}",
1961
+ "ema": f"{ema_loss:.3f}",
1962
+ "best_val": f"{best_val:.3f}",
1963
+ "lr": f"{optimizer.param_groups[0]['lr']:.2e}",
1964
+ })
1965
+
1966
+ if global_step >= max_steps:
1967
+ break
1968
+
1969
+ pbar.close()
1970
+
1971
+ # reload best
1972
+ if os.path.exists(BEST_PATH):
1973
+ model.load_state_dict(torch.load(BEST_PATH, map_location=device))
1974
+ print("Loaded best checkpoint:", BEST_PATH, " (val_loss:", f"{best_val:.4f}", ")")
1975
+
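+ # Optional check: the reloaded weights should reproduce (approximately) the best
+ # validation loss reported during training.
+ print("Val loss after reload:", f"{run_eval(valid_loader):.4f}")
+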
1976
+ """### 2.11 Classification evaluation"""
1977
+
1978
+ # --- Build context-only inputs (text </s>) directly from raw test_df ---
1979
+ def make_context_only(df):
1980
+ texts = df["text"].astype(str).tolist()
1981
+ bpes = bpe_encode_lines(texts, desc="BPE test ctx")
1982
+ ctx = []
1983
+ for bpe in bpes:
1984
+ ids, _ = tokens_to_ids(bpe)
1985
+ ctx.append(np.array(ids + [eos_id], dtype=np.int64))
1986
+ return ctx
1987
+
1988
+ def pad_batch(seqs):
1989
+ L = max(len(s) for s in seqs)
1990
+ out = np.full((len(seqs), L), pad_id, dtype=np.int64)
1991
+ for i, s in enumerate(seqs):
1992
+ out[i, :len(s)] = s
1993
+ return torch.from_numpy(out)
1994
+
1995
+ def ids_to_tokens(ids):
1996
+ return [id2token.get(int(i), "<unk>") for i in ids]
1997
+
1998
+ def to_canonical(pred_chunk: str):
1999
+ s = (pred_chunk or "").strip().lower()
2000
+ low = [L.lower() for L in HALLMARKS]
2001
+ if s in low:
2002
+ return HALLMARKS[low.index(s)]
2003
+ best = difflib.get_close_matches(s, low, n=1, cutoff=0.7)
2004
+ return HALLMARKS[low.index(best[0])] if best else None
2005
+
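+ # e.g. (illustrative): fuzzy matching should absorb small spelling variations, while
+ # unrelated strings map to None and are simply dropped from the prediction.
+ print(to_canonical("resisting cell deaths"), "|", to_canonical("completely unrelated text"))
+ # expected: "resisting cell death" | None
+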
2006
+ def predict_labels_for_batch(xb):
2007
+ """xb: (B, T) context-only input ids (text </s>)."""
2008
+ with torch.no_grad():
2009
+ gens = model.generate_labels(xb, allowed_mask=ALLOWED_MASK, max_new_tokens=24, temperature=0.0)
2010
+ preds = []
2011
+ for g in gens:
2012
+ toks = ids_to_tokens(g.detach().cpu().numpy())
2013
+ # cut at EOS
2014
+ toks = toks[: toks.index("</s>")] if "</s>" in toks else toks
2015
+ label_str = bpe_decode_tokens(toks).strip().lower()
2016
+
2017
+ # split multi-label guesses
2018
+ parts = []
2019
+ for sep in [",", ";", "|"]:
2020
+ if sep in label_str:
2021
+ parts = [p.strip() for p in label_str.split(sep) if p.strip()]
2022
+ break
2023
+ if not parts:
2024
+ parts = [label_str] if label_str else []
2025
+
2026
+ # map to canonical hallmarks (no default to 'empty')
2027
+ mapped = []
2028
+ for p in parts:
2029
+ can = to_canonical(p)
2030
+ if can and can not in mapped:
2031
+ mapped.append(can)
2032
+ preds.append(mapped) # may be []
2033
+ return preds
2034
+
2035
+ # --- Run decoding on TEST ---
2036
+ model.eval()
2037
+ ctx_test = make_context_only(test_df)
2038
+
2039
+ B = 32
2040
+ preds_all = []
2041
+ for i in tqdm(range(0, len(ctx_test), B), desc="Decoding (test)"):
2042
+ batch_ctx = pad_batch(ctx_test[i:i+B]).to(device)
2043
+ preds_all.extend(predict_labels_for_batch(batch_ctx))
2044
+
2045
+ # --- Build ground truth (hallmarks only) ---
2046
+ y_true = [ normalize_labels(split_labels(s)) for s in test_df["label"].astype(str).tolist() ]
2047
+
2048
+ # --- Binarize and score (10 hallmarks only) ---
2049
+ from sklearn.metrics import precision_recall_fscore_support
2050
+ LABELS = HALLMARKS
2051
+ LIDX = {l:i for i,l in enumerate(LABELS)}
2052
+
2053
+ def binarize(labs):
2054
+ v = [0]*len(LABELS)
2055
+ for l in labs:
2056
+ if l in LIDX:
2057
+ v[LIDX[l]] = 1
2058
+ return v
2059
+
2060
+ Y_true = np.array([binarize(labs) for labs in y_true], dtype=np.int64)
2061
+ Y_pred = np.array([binarize(labs) for labs in preds_all], dtype=np.int64)
2062
+
2063
+ micro_p, micro_r, micro_f1, _ = precision_recall_fscore_support(Y_true, Y_pred, average='micro', zero_division=0)
2064
+ macro_p, macro_r, macro_f1, _ = precision_recall_fscore_support(Y_true, Y_pred, average='macro', zero_division=0)
2065
+
2066
+ print(f"\nHALLMARKS-ONLY Micro P/R/F1: {micro_p:.4f} / {micro_r:.4f} / {micro_f1:.4f}")
2067
+ print( f"HALLMARKS-ONLY Macro P/R/F1: {macro_p:.4f} / {macro_r:.4f} / {macro_f1:.4f}")
2068
+
2069
+ perclass = precision_recall_fscore_support(Y_true, Y_pred, average=None, zero_division=0)
2070
+ per_df = pd.DataFrame({
2071
+ "label": LABELS,
2072
+ "precision": perclass[0],
2073
+ "recall": perclass[1],
2074
+ "f1": perclass[2],
2075
+ "support": perclass[3],
2076
+ }).sort_values("label")
2077
+
2078
+ print("\nPer-class results (10 hallmarks):")
2079
+ print(per_df.to_string(index=False))
2080
+
2081
+ per_df.to_csv("hoc_test_results_per_class.csv", index=False)
2082
+ print("Saved: hoc_test_results_per_class.csv")