import sys
import time
from pathlib import Path
from typing import Dict, List, Optional

# Add project root to sys.path so the project-level `config` module is importable.
sys.path.append(str(Path(__file__).parent.parent))

import config  # noqa: E402  (must come after the sys.path tweak above)
from sentence_transformers import CrossEncoder  # noqa: E402


class Reranker:
    """Re-score retrieved chunks against a query using a sentence-transformers CrossEncoder."""

    def __init__(self, model_name: Optional[str] = None):
        """Load the cross-encoder model.

        Args:
            model_name: Model identifier passed to ``CrossEncoder``. When empty or
                ``None``, ``config.RERANKER_MODEL_NAME`` is used instead.
        """
        self.model_name = model_name if model_name else config.RERANKER_MODEL_NAME
        print(f"Loading Reranker Model: {self.model_name}...")
        self.model = CrossEncoder(self.model_name)

    @staticmethod
    def _build_context(res: Dict) -> str:
        """Build the passage text for one result: citation, article title, then chunk text.

        Empty or ``None`` metadata fields are skipped so they neither inject stray
        whitespace nor the literal string "None" into the model input.
        """
        chunk = res['chunk']
        meta = chunk.get('metadata') or {}
        parts = (
            meta.get('citation_short'),
            meta.get('article_title'),
            chunk['text'],
        )
        return " ".join(str(part) for part in parts if part)

    def rank(
        self,
        query: str,
        chunks: List[Dict],
        top_k: int = 5,
        batch_size: int = 4,
    ) -> List[Dict]:
        """Rerank ``chunks`` by cross-encoder relevance to ``query``.

        Args:
            query: The search query.
            chunks: Retrieval results. Each item is a dict with a ``'chunk'`` dict
                holding ``'text'`` and optionally ``'metadata'`` (read for
                ``'citation_short'`` and ``'article_title'``).
            top_k: Number of results to return. Floats (e.g. from Gradio sliders)
                are truncated to int; negative values are treated as 0.
            batch_size: Batch size for ``CrossEncoder.predict``. Lower it if
                memory is tight.

        Returns:
            Up to ``top_k`` of the input dicts, sorted by descending
            ``'rerank_score'``.

        Note:
            Every input dict is mutated in place: a float ``'rerank_score'`` key
            is added to each one, not only to the returned ones.
        """
        start_time = time.perf_counter()

        # Gradio sliders may pass floats. Clamp at 0 so a negative value never
        # slices from the end of the list.
        top_k = max(int(top_k), 0)

        if not chunks:
            return []

        print(f"Reranking {len(chunks)} chunks on {self.model.device}...")

        # Cross-encoder input pairs: (query, citation + title + text). Metadata is
        # included in the context to give the model more signal.
        pairs = [[query, self._build_context(res)] for res in chunks]

        scores = self.model.predict(pairs, batch_size=batch_size, show_progress_bar=True)

        for res, score in zip(chunks, scores):
            res['rerank_score'] = float(score)

        ranked_chunks = sorted(chunks, key=lambda x: x['rerank_score'], reverse=True)

        duration = time.perf_counter() - start_time
        print(f"Reranking completed in {duration:.2f} seconds.")
        return ranked_chunks[:top_k]