File size: 2,038 Bytes
2303ac9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import sys
from pathlib import Path
from typing import List, Dict

# Add project root to sys.path
sys.path.append(str(Path(__file__).parent.parent))
import config

from sentence_transformers import CrossEncoder

class Reranker:
    """Rerank retrieved chunks against a query with a cross-encoder.

    Wraps a ``CrossEncoder`` model: each (query, chunk context) pair is
    scored by the model and the chunks are returned best-first.
    """

    def __init__(self, model_name: str = None):
        """Load the cross-encoder model.

        Args:
            model_name: Model identifier handed to ``CrossEncoder``. When
                falsy (``None`` or empty), ``config.RERANKER_MODEL_NAME``
                is used instead.
        """
        self.model_name = model_name if model_name else config.RERANKER_MODEL_NAME
        print(f"Loading Reranker Model: {self.model_name}...")
        self.model = CrossEncoder(self.model_name)

    @staticmethod
    def _build_context(chunk: Dict) -> str:
        """Return the text scored against the query for one chunk.

        The short citation and article title (empty when absent from the
        metadata) are prepended to the chunk text so the model also sees
        where the passage comes from.
        """
        meta = chunk['metadata']
        return f"{meta.get('citation_short', '')} {meta.get('article_title', '')} {chunk['text']}"

    def rank(self, query: str, chunks: List[Dict], top_k: int = 5) -> List[Dict]:
        """
        Reranks a list of chunks based on the query.
        Returns the top_k chunks with updated scores.

        Args:
            query: The search query every chunk is scored against.
            chunks: Retrieval results. Each item must hold a ``'chunk'`` dict
                with ``'text'`` and ``'metadata'`` keys.
            top_k: Maximum number of results to return. Floats are truncated
                to int; values <= 0 yield an empty list.

        Returns:
            Up to ``top_k`` items of ``chunks``, sorted by descending
            ``'rerank_score'``.

        Note:
            The input dicts are updated in place: every item in ``chunks``
            (not only the returned ones) gets a ``'rerank_score'`` float.
        """
        import time

        # perf_counter is monotonic, so the reported duration cannot be
        # skewed by system clock adjustments.
        started = time.perf_counter()

        # Ensure top_k is an integer (handle Gradio sliders passing floats).
        # Clamp at zero: a negative value would otherwise act as a negative
        # slice bound and silently drop items from the tail of the ranking.
        top_k = max(int(top_k), 0)

        if not chunks:
            return []

        print(f"Reranking {len(chunks)} chunks on {self.model.device}...")

        # Prepare (query, citation + title + text) pairs for the cross-encoder.
        pairs = [[query, self._build_context(res['chunk'])] for res in chunks]

        # Predict scores; the small batch size keeps memory usage low.
        scores = self.model.predict(pairs, batch_size=4, show_progress_bar=True)

        # Attach each score to its result (predict returns one per pair,
        # in input order).
        for res, score in zip(chunks, scores):
            res['rerank_score'] = float(score)

        # Highest score first; sorted() is stable, so ties keep input order.
        ranked_chunks = sorted(chunks, key=lambda res: res['rerank_score'], reverse=True)

        duration = time.perf_counter() - started
        print(f"Reranking completed in {duration:.2f} seconds.")

        return ranked_chunks[:top_k]