import json
import torch
import numpy as np
from pathlib import Path
import re
import sys
import urllib.request

sys.path.append(str(Path(__file__).parent.parent))
import config

import pymorphy3
from symspellpy import SymSpell, Verbosity
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, util
from typing import List, Dict, Optional
from functools import lru_cache


class Retriever:
    """
    Hybrid retrieval system combining BM25 and semantic search.
    Supports Ukrainian text processing with spell checking and lemmatization.
    """

    def __init__(
        self,
        data_path: Optional[str] = None,
        model_name: Optional[str] = None,
        enable_spell_check: bool = True,
        spell_check_threshold: float = 0.5,
        max_edit_distance: int = 2
    ):
        """
        Initialize retriever with data and embedding model.

        Args:
            data_path: Path to chunks JSON file
            model_name: Name of sentence transformer model
            enable_spell_check: Enable spell checking for unknown words
            spell_check_threshold: Confidence threshold for spell checking (0-1)
            max_edit_distance: Maximum edit distance for spell corrections
        """
        self.data_path = Path(data_path) if data_path else config.CHUNKS_PATH
        self.model_name = model_name or config.EMBEDDING_MODEL_NAME
        self.enable_spell_check = enable_spell_check
        self.spell_check_threshold = spell_check_threshold
        self.max_edit_distance = max_edit_distance

        # Load chunks
        self._load_chunks()

        # Initialize components
        self._init_nlp_tools()
        self._init_bm25()
        self._init_semantic_model()

    def _load_chunks(self) -> None:
        """Load text chunks from JSON file, filtering out stub articles."""
        print(f"Loading data from {self.data_path}...")
        with open(self.data_path, 'r', encoding='utf-8') as f:
            all_chunks = json.load(f)

        # Filter out stub chunks (e.g., "Стаття 207." with no content)
        self.chunks = [
            chunk for chunk in all_chunks
            if len(chunk.get('text', '')) >= config.MIN_CHUNK_LENGTH
        ]
        filtered_count = len(all_chunks) - len(self.chunks)

        self.corpus_texts = [chunk['text'] for chunk in self.chunks]
        print(f"Loaded {len(self.chunks)} chunks (filtered {filtered_count} stubs)")

    def _download_ukrainian_dictionary(self, dict_path: Path) -> None:
        """Download Ukrainian frequency dictionary for spell checking."""
        print("Downloading Ukrainian dictionary for spell checking...")
        url = "https://raw.githubusercontent.com/brown-uk/dict_uk/master/data/dict/uk_words.txt"
        try:
            dict_path.parent.mkdir(parents=True, exist_ok=True)
            urllib.request.urlretrieve(url, dict_path)
            print(f"Dictionary downloaded to {dict_path}")
        except Exception as e:
            print(f"Warning: Could not download dictionary: {e}")
            print("Creating basic dictionary from corpus...")
            self._create_corpus_dictionary(dict_path)

    def _create_corpus_dictionary(self, dict_path: Path) -> None:
        """Create a basic frequency dictionary from the corpus."""
        from collections import Counter

        # Extract all words from corpus
        all_words = []
        for text in self.corpus_texts:
            tokens = re.findall(r'\w+', text.lower())
            all_words.extend(tokens)

        # Count frequencies
        word_freq = Counter(all_words)

        # Save in SymSpell format: word frequency
        dict_path.parent.mkdir(parents=True, exist_ok=True)
        with open(dict_path, 'w', encoding='utf-8') as f:
            for word, freq in word_freq.most_common():
                f.write(f"{word} {freq}\n")

        print(f"Created corpus-based dictionary with {len(word_freq)} words")

    def _init_nlp_tools(self) -> None:
        """Initialize Ukrainian morphological analyzer and spell checker."""
        print("Initializing Ukrainian NLP tools...")

        # Morphological analyzer
        self.morph = pymorphy3.MorphAnalyzer(lang='uk')

        # Spell checker (SymSpell - no Java required!)
        if self.enable_spell_check:
            print("Initializing spell checker...")
            try:
                self.spell_checker = SymSpell(max_dictionary_edit_distance=self.max_edit_distance)

                # Dictionary path
                dict_path = Path("data/uk_dictionary.txt")

                # Download or create dictionary if it does not exist
                if not dict_path.exists():
                    self._download_ukrainian_dictionary(dict_path)

                # Load dictionary
                if self.spell_checker.load_dictionary(
                    str(dict_path), term_index=0, count_index=1, separator=" "
                ):
                    print(f"Spell checker ready with {len(self.spell_checker.words)} words")
                else:
                    print("Warning: Could not load dictionary")
                    self.enable_spell_check = False
            except Exception as e:
                print(f"Warning: Could not initialize spell checker: {e}")
                print("Continuing without spell checking...")
                self.enable_spell_check = False

    def _init_bm25(self) -> None:
        """Initialize BM25 index with enriched text."""
        print("Building BM25 index...")
        self.tokenized_corpus = []
        for chunk in self.chunks:
            # Enrich text with metadata for better retrieval
            meta = chunk.get('metadata', {})
            enriched_text = " ".join([
                meta.get('citation_short', ''),
                meta.get('article_title', ''),
                chunk['text']
            ])
            self.tokenized_corpus.append(self._tokenize_and_lemmatize(enriched_text))

        self.bm25 = BM25Okapi(self.tokenized_corpus)
        print("BM25 index ready")

    def _init_semantic_model(self) -> None:
        """Initialize semantic embeddings model and load/compute embeddings."""
        print(f"Loading semantic model: {self.model_name}...")

        # If memory is tight, we might want to put the embedding model on CPU
        # to leave room for the heavy reranker on MPS (if available).
        device = "cpu"
        if torch.backends.mps.is_available():
            # For now, keep it on CPU to save MPS memory for the reranker.
            # Embedding is fast enough on CPU for small batches.
            print("Placing embedding model on CPU to reserve MPS for reranker.")
            device = "cpu"

        self.model = SentenceTransformer(self.model_name, device=device)

        embed_path = config.EMBEDDINGS_PATH
        if embed_path.exists() and self._validate_embeddings(embed_path):
            print("Loading cached embeddings...")
            self.embeddings = torch.load(embed_path)
            # Ensure embeddings are on the same device as the model
            self.embeddings = self.embeddings.to(self.model.device)
        else:
            self._compute_and_save_embeddings(embed_path)

    def _validate_embeddings(self, path: Path) -> bool:
        """Check if cached embeddings match the current corpus."""
        try:
            embeddings = torch.load(path)
            return len(embeddings) == len(self.chunks)
        except Exception as e:
            print(f"Error loading embeddings: {e}")
            return False

    def _compute_and_save_embeddings(self, save_path: Path) -> None:
        """Compute embeddings for all corpus texts and save to disk."""
        print("Computing embeddings (this may take a while)...")
        self.embeddings = self.model.encode(
            self.corpus_texts,
            convert_to_tensor=True,
            show_progress_bar=True,
            batch_size=32
        )
        save_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(self.embeddings, save_path)
        print(f"Embeddings saved to {save_path}")

    @lru_cache(maxsize=10000)
    def _correct_word(self, word: str) -> str:
        """
        Correct the spelling of a single word using SymSpell.
        Results are cached for performance.

        Args:
            word: Word to correct

        Returns:
            Corrected word
        """
        if not self.enable_spell_check:
            return word

        try:
            # Look up suggestions
            suggestions = self.spell_checker.lookup(
                word,
                Verbosity.CLOSEST,
                max_edit_distance=self.max_edit_distance,
                include_unknown=False
            )
            if suggestions and suggestions[0].term != word:
                corrected = suggestions[0].term
                print(f"Spell correction: '{word}' → '{corrected}' (distance: {suggestions[0].distance})")
                return corrected
        except Exception:
            # Silently fail and return the original word
            pass

        return word

    def _tokenize_and_lemmatize(self, text: str) -> List[str]:
        """
        Tokenize, spell-check (if enabled), and lemmatize Ukrainian text.

        Args:
            text: Input text

        Returns:
            List of lemmatized tokens
        """
        # Extract words (Cyrillic, Latin, numbers)
        tokens = re.findall(r'\w+', text.lower())

        lemmas = []
        for token in tokens:
            # Skip very short tokens
            if len(token) < 2:
                lemmas.append(token)
                continue

            # Parse with pymorphy3
            parsed = self.morph.parse(token)[0]

            # If the word is unknown (low confidence) and spell checking is enabled
            if self.enable_spell_check and parsed.score < self.spell_check_threshold:
                # Try to correct spelling
                corrected_token = self._correct_word(token)
                # Re-parse corrected word
                if corrected_token != token:
                    parsed = self.morph.parse(corrected_token)[0]

            lemmas.append(parsed.normal_form)

        return lemmas

    def search(
        self,
        query: str,
        top_k: int = 30,
        method: str = 'hybrid',
        alpha: Optional[float] = None,  # Uses config.HYBRID_ALPHA if None
        legal_area: Optional[str] = None
    ) -> List[Dict]:
        """
        Search for relevant chunks using the specified method.

        Args:
            query: Search query
            top_k: Number of results to return
            method: 'bm25', 'semantic', or 'hybrid'
            alpha: Weight for hybrid search (semantic weight, 0-1)
            legal_area: Optional filter by metadata['legal_area']

        Returns:
            List of result dictionaries with chunk and score
        """
        # Ensure top_k is an integer (handle Gradio sliders passing floats)
        top_k = int(top_k)

        # Pre-filter chunks by legal_area if provided
        filtered_indices = None
        if legal_area and legal_area != "Всі":
            filtered_indices = [
                i for i, chunk in enumerate(self.chunks)
                if chunk.get('metadata', {}).get('legal_area') == legal_area
            ]
            if not filtered_indices:
                return []

        method_map = {
            'bm25': self._search_bm25,
            'semantic': self._search_semantic,
            'hybrid': self._search_hybrid
        }
        if method not in method_map:
            raise ValueError(f"Unknown method: {method}. Use 'bm25', 'semantic', or 'hybrid'")

        search_func = method_map[method]
        if method == 'hybrid':
            effective_alpha = alpha if alpha is not None else config.HYBRID_ALPHA
            return search_func(query, top_k, effective_alpha, filtered_indices=filtered_indices)
        return search_func(query, top_k, filtered_indices=filtered_indices)

    def _search_bm25(self, query: str, top_k: int, filtered_indices: Optional[List[int]] = None) -> List[Dict]:
        """Keyword-based BM25 search with spell correction."""
        tokenized_query = self._tokenize_and_lemmatize(query)
        scores = self.bm25.get_scores(tokenized_query)

        if filtered_indices is not None:
            # Mask scores for non-filtered chunks
            mask = np.zeros(len(scores), dtype=bool)
            mask[filtered_indices] = True
            scores[~mask] = -1e9

        top_indices = np.argsort(scores)[::-1][:top_k]

        # Filter out masked scores from results
        return [
            {
                'chunk': self.chunks[idx],
                'score': float(scores[idx]),
                'method': 'bm25'
            }
            for idx in top_indices
            if scores[idx] > -1e8
        ]

    def _search_semantic(self, query: str, top_k: int, filtered_indices: Optional[List[int]] = None) -> List[Dict]:
        """Semantic similarity search using embeddings."""
        query_embedding = self.model.encode(query, convert_to_tensor=True)

        if filtered_indices is not None:
            # Filter embeddings before searching
            filtered_embeddings = self.embeddings[filtered_indices]
            hits = util.semantic_search(query_embedding, filtered_embeddings, top_k=top_k)[0]
            return [
                {
                    'chunk': self.chunks[filtered_indices[hit['corpus_id']]],
                    'score': float(hit['score']),
                    'method': 'semantic'
                }
                for hit in hits
            ]

        hits = util.semantic_search(query_embedding, self.embeddings, top_k=top_k)[0]
        return [
            {
                'chunk': self.chunks[hit['corpus_id']],
                'score': float(hit['score']),
                'method': 'semantic'
            }
            for hit in hits
        ]

    def _search_hybrid(self, query: str, top_k: int, alpha: float, filtered_indices: Optional[List[int]] = None) -> List[Dict]:
        """
        Hybrid search combining BM25 and semantic similarity.

        Score = alpha * semantic_score + (1 - alpha) * bm25_score
        """
        # Get BM25 scores
        tokenized_query = self._tokenize_and_lemmatize(query)
        bm25_scores = self.bm25.get_scores(tokenized_query)

        # Get semantic scores
        query_embedding = self.model.encode(query, convert_to_tensor=True)
        semantic_scores = util.cos_sim(query_embedding, self.embeddings)[0].cpu().numpy()

        # Mask scores if filtered_indices is provided
        if filtered_indices is not None:
            mask = np.zeros(len(self.chunks), dtype=bool)
            mask[filtered_indices] = True
            bm25_scores[~mask] = 0.0       # BM25 min is typically 0
            semantic_scores[~mask] = -1.0  # Cosine min is -1

        # Normalize scores to [0, 1]
        bm25_norm = self._min_max_normalize(bm25_scores)
        semantic_norm = self._min_max_normalize(semantic_scores)

        # Combine with weighted sum
        combined_scores = alpha * semantic_norm + (1 - alpha) * bm25_norm

        # Re-apply mask after normalization just in case
        if filtered_indices is not None:
            mask = np.zeros(len(self.chunks), dtype=bool)
            mask[filtered_indices] = True
            combined_scores[~mask] = -1.0

        # Apply BM25 threshold: penalize chunks with no keyword overlap
        bm25_max = np.max(bm25_scores) if np.max(bm25_scores) > 0 else 1.0
        bm25_relative = bm25_scores / bm25_max
        low_bm25_mask = bm25_relative < config.MIN_BM25_SCORE
        combined_scores[low_bm25_mask] *= 0.5  # Penalize rather than exclude

        # Get top-k results
        top_indices = np.argsort(combined_scores)[::-1][:top_k]

        return [
            {
                'chunk': self.chunks[idx],
                'score': float(combined_scores[idx]),
                'bm25_score': float(bm25_scores[idx]),
                'semantic_score': float(semantic_scores[idx]),
                'method': 'hybrid'
            }
            for idx in top_indices
            if combined_scores[idx] > -0.9
        ]

    @staticmethod
    def _min_max_normalize(scores: np.ndarray) -> np.ndarray:
        """Normalize scores to [0, 1] range using min-max scaling."""
        min_score = np.min(scores)
        max_score = np.max(scores)
        if max_score == min_score:
            return scores
        return (scores - min_score) / (max_score - min_score)
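

# Minimal usage sketch (illustrative only, not part of the pipeline): it assumes
# config.CHUNKS_PATH, config.EMBEDDING_MODEL_NAME, config.EMBEDDINGS_PATH,
# config.MIN_CHUNK_LENGTH, config.HYBRID_ALPHA, and config.MIN_BM25_SCORE are
# defined and the chunks JSON exists. The query string below is a placeholder.
if __name__ == "__main__":
    # Spell checking is disabled here to avoid the dictionary download during a quick smoke test.
    retriever = Retriever(enable_spell_check=False)
    results = retriever.search("розірвання договору оренди", top_k=5, method='hybrid')
    for result in results:
        meta = result['chunk'].get('metadata', {})
        print(f"{result['score']:.3f}  {meta.get('citation_short', '')}  {result['chunk']['text'][:80]}...")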