# rag/modules/reranker.py
# Author: dmytrotm
# Initial commit 2303ac9
import sys
import time
from pathlib import Path
from typing import Dict, List, Optional

# Add project root to sys.path
sys.path.append(str(Path(__file__).parent.parent))

from sentence_transformers import CrossEncoder

import config
class Reranker:
    """Re-score retrieved chunks against a query with a sentence-transformers CrossEncoder."""

    def __init__(
        self,
        model_name: Optional[str] = None,
        batch_size: int = 4,
        show_progress_bar: bool = True,
    ):
        """Load the cross-encoder model.

        Args:
            model_name: Model identifier passed to ``CrossEncoder``; falls back to
                ``config.RERANKER_MODEL_NAME`` when empty or None.
            batch_size: Number of (query, passage) pairs scored per ``predict``
                batch. Lower it if memory is tight (default 4, the previously
                hard-coded value).
            show_progress_bar: Forwarded to ``CrossEncoder.predict``.
        """
        self.model_name = model_name if model_name else config.RERANKER_MODEL_NAME
        self.batch_size = batch_size
        self.show_progress_bar = show_progress_bar
        print(f"Loading Reranker Model: {self.model_name}...")
        self.model = CrossEncoder(self.model_name)

    def rank(self, query: str, chunks: List[Dict], top_k: int = 5) -> List[Dict]:
        """Rerank ``chunks`` by cross-encoder relevance to ``query``.

        Each element of ``chunks`` is expected to be a dict holding a ``'chunk'``
        dict with a ``'text'`` value and an optional ``'metadata'`` dict
        (``citation_short`` and ``article_title`` are used when present).

        Side effect: every input dict gets a ``'rerank_score'`` float set in place.

        Args:
            query: The query to score passages against.
            chunks: Retrieval results to rerank.
            top_k: Maximum number of results to return. Floats (e.g. from Gradio
                sliders) are truncated with ``int()``; negative values act as 0.

        Returns:
            Up to ``top_k`` of the input dicts, sorted by ``'rerank_score'``
            in descending order.
        """
        # perf_counter is monotonic, unlike time.time(), so it is the right
        # clock for measuring elapsed durations.
        start_time = time.perf_counter()
        # Gradio sliders may pass floats. Clamp at 0: a negative slice bound would
        # otherwise drop items from the *end* instead of limiting the count.
        top_k = max(int(top_k), 0)
        if not chunks:
            return []

        print(f"Reranking {len(chunks)} chunks on {self.model.device}...")

        # Pairs are (query, citation + title + text): including metadata gives the
        # cross-encoder extra context for better accuracy.
        pairs = [[query, self._context_text(res["chunk"])] for res in chunks]

        scores = self.model.predict(
            pairs,
            batch_size=self.batch_size,
            show_progress_bar=self.show_progress_bar,
        )

        for res, score in zip(chunks, scores):
            res["rerank_score"] = float(score)

        # sorted() is stable, so equal scores keep their retrieval order.
        ranked_chunks = sorted(chunks, key=lambda res: res["rerank_score"], reverse=True)
        duration = time.perf_counter() - start_time
        print(f"Reranking completed in {duration:.2f} seconds.")
        return ranked_chunks[:top_k]

    @staticmethod
    def _context_text(chunk: Dict) -> str:
        """Join citation, article title and chunk text, skipping empty/None parts."""
        # Missing metadata or None-valued fields must not leak the literal
        # string "None" into the model input.
        meta = chunk.get("metadata") or {}
        parts = (meta.get("citation_short"), meta.get("article_title"), chunk["text"])
        return " ".join(str(part) for part in parts if part)