|
|
import sys
import time
from pathlib import Path
from typing import Dict, List, Optional
|
|
|
|
|
|
|
|
sys.path.append(str(Path(__file__).parent.parent)) |
|
|
import config |
|
|
|
|
|
from sentence_transformers import CrossEncoder |
|
|
|
|
|
class Reranker:
    """Rerank retrieved chunks with a ``sentence_transformers.CrossEncoder``.

    Each chunk is scored against the query as a ``[query, context]`` pair,
    where the context is the chunk's short citation, article title and text.
    """

    def __init__(self, model_name: Optional[str] = None):
        """Load the cross-encoder model.

        Args:
            model_name: Model identifier passed to ``CrossEncoder``. When falsy
                (``None`` or ``""``), ``config.RERANKER_MODEL_NAME`` is used.
        """
        self.model_name = model_name or config.RERANKER_MODEL_NAME
        print(f"Loading Reranker Model: {self.model_name}...")
        self.model = CrossEncoder(self.model_name)

    @staticmethod
    def _build_context(chunk: Dict) -> str:
        """Return the text scored for one chunk: citation, title, then body text."""
        # Tolerate chunks without metadata, and metadata fields explicitly set
        # to None (which would otherwise be rendered as the literal "None").
        meta = chunk.get('metadata') or {}
        citation = meta.get('citation_short')
        title = meta.get('article_title')
        citation = '' if citation is None else citation
        title = '' if title is None else title
        return f"{citation} {title} {chunk['text']}"

    def rank(
        self,
        query: str,
        chunks: List[Dict],
        top_k: int = 5,
        *,
        batch_size: int = 4,
        show_progress_bar: bool = True,
    ) -> List[Dict]:
        """Rerank ``chunks`` against ``query`` and return the best ``top_k``.

        Args:
            query: The search query.
            chunks: Retrieval results. Each item is a dict with a ``'chunk'``
                key whose value holds ``'text'`` and, optionally,
                ``'metadata'`` (``'citation_short'`` / ``'article_title'``).
            top_k: Maximum number of results to return. It is coerced with
                ``int()``. Negative values are treated as 0.
            batch_size: Batch size forwarded to ``CrossEncoder.predict``.
            show_progress_bar: Whether ``CrossEncoder.predict`` shows progress.

        Returns:
            Up to ``top_k`` items from ``chunks``, sorted by descending
            ``'rerank_score'``. The sort is stable, so ties keep their input
            order.

        Note:
            Every item in ``chunks`` is mutated in place: a float
            ``'rerank_score'`` key is added to it.
        """
        start_time = time.perf_counter()  # monotonic clock for durations

        # Clamp negatives: slicing with a negative bound would otherwise return
        # "all but the last k" results instead of an empty list.
        top_k = max(int(top_k), 0)

        if not chunks:
            return []

        print(f"Reranking {len(chunks)} chunks on {self.model.device}...")

        pairs = [[query, self._build_context(res['chunk'])] for res in chunks]
        scores = self.model.predict(
            pairs, batch_size=batch_size, show_progress_bar=show_progress_bar
        )

        for res, score in zip(chunks, scores):
            res['rerank_score'] = float(score)

        ranked_chunks = sorted(chunks, key=lambda res: res['rerank_score'], reverse=True)

        duration = time.perf_counter() - start_time
        print(f"Reranking completed in {duration:.2f} seconds.")

        return ranked_chunks[:top_k]
|
|
|