Complete RAG HW requirements: add legal area filtering, refine citations, and update documentation
Browse files- .gitignore +2 -1
- README.md +47 -40
- app.py +43 -8
- assets/style.css +20 -2
- config.py +15 -14
- modules/rag_system.py +4 -3
- modules/retriever.py +56 -8
.gitignore
CHANGED
|
@@ -5,4 +5,5 @@ __pycache__
|
|
| 5 |
*debug*
|
| 6 |
*test*
|
| 7 |
*verify*
|
| 8 |
-
*example*
|
|
|
|
|
|
| 5 |
*debug*
|
| 6 |
*test*
|
| 7 |
*verify*
|
| 8 |
+
*example*
|
| 9 |
+
*check*
|
README.md
CHANGED
|
@@ -9,59 +9,66 @@ app_file: app.py
|
|
| 9 |
pinned: false
|
| 10 |
---
|
| 11 |
|
| 12 |
-
# 🇺🇦
|
| 13 |
|
| 14 |
-
|
| 15 |
-
This system answers questions based on the Criminal Code, Civil Code, and other legal documents, providing precise citations for every answer.
|
| 16 |
|
| 17 |
-
##
|
| 18 |
-
-
|
| 19 |
-
- **Reranking**:
|
| 20 |
-
-
|
| 21 |
-
-
|
|
|
|
| 22 |
|
| 23 |
-
##
|
| 24 |
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
```bash
|
| 27 |
pip install -r requirements.txt
|
| 28 |
```
|
| 29 |
|
| 30 |
-
2.
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
python scripts/parser.py
|
| 35 |
```
|
| 36 |
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
The interface will be available at `http://localhost:7860`.
|
| 43 |
-
|
| 44 |
-
## Configuration
|
| 45 |
-
- Open `config.py` to change models or default parameters.
|
| 46 |
-
- You will need a **Groq API Key** to generate answers. Get one at [console.groq.com](https://console.groq.com).
|
| 47 |
|
| 48 |
-
##
|
| 49 |
```
|
| 50 |
-
|
| 51 |
-
├──
|
| 52 |
-
├──
|
| 53 |
-
├──
|
| 54 |
-
├── data/ # Data storage (chunks, embeddings)
|
| 55 |
├── modules/
|
| 56 |
-
│ ├── rag_system.py #
|
| 57 |
-
│ ├── retriever.py #
|
| 58 |
-
│ ├── reranker.py # Cross-Encoder
|
| 59 |
-
│ └── llm_handler.py # LiteLLM
|
| 60 |
└── scripts/
|
| 61 |
-
└── parser.py #
|
| 62 |
```
|
| 63 |
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
2. **Concept-based**: `Як звільнити працівника за прогул` (Will find relevant Labor Code articles)
|
| 67 |
-
3. **Complex**: `Яка різниця між крадіжкою і грабежем?`
|
|
|
|
| 9 |
pinned: false
|
| 10 |
---
|
| 11 |
|
| 12 |
+
# 🇺🇦 Асистент із Законодавства України (RAG QA)
|
| 13 |
|
| 14 |
+
Ця система дозволяє користувачам отримувати відповіді на юридичні запитання, базуючись на актуальних кодексах та законах України за допомогою підходу **Retrieval-Augmented Generation (RAG)**.
|
|
|
|
| 15 |
|
| 16 |
+
## Основні функції
|
| 17 |
+
- **Гібридний пошук**: Поєднання BM25 (ключові слова) та SBERT (семантика) для максимального охоплення.
|
| 18 |
+
- **Reranking**: Використання Cross-Encoder моделі для високої точності.
|
| 19 |
+
- **Цитування**: Автоматичне посилання на статті кодексів у тексті відповіді.
|
| 20 |
+
- **Фільтрація за метаданими**: Можливість звузити пошук до конкретної галузі права (Кримінальне, Цивільне тощо).
|
| 21 |
+
- **Інтелектуальні відповіді**: Використання **Llama 3.3 70B** для генерації відповідей українською мовою.
|
| 22 |
|
| 23 |
+
## Технічна архітектура
|
| 24 |
|
| 25 |
+
- **Retriever**:
|
| 26 |
+
- **BM25**: Лематизація (pymorphy3) та спелчекінг (SymSpell).
|
| 27 |
+
- **Semantic Search**: Модель `sentence-transformers/paraphrase-multilingual-mpnet-base-v2`.
|
| 28 |
+
- **Reranker**: Модель `cross-encoder/ms-marco-TinyBERT-L-2-v2`.
|
| 29 |
+
- **LLM**: `llama-3.3-70b-versatile` через API Groq (LiteLLM).
|
| 30 |
+
- **Metadata**: Реалізовано пере-фільтрацію за полем `legal_area`.
|
| 31 |
+
|
| 32 |
+
## Порівняння методів пошуку
|
| 33 |
+
|
| 34 |
+
| Запит | Кращий метод | Чому саме він? |
|
| 35 |
+
| :--- | :--- | :--- |
|
| 36 |
+
| "Стаття 115 ККУ" | **BM25** | Точний збіг за номером статті та назвою кодексу. |
|
| 37 |
+
| "права батьків після розлучення" | **Semantic Search** | Розуміє концепцію "сімейних прав", навіть якщо ці слова не зустрічаються буквально. |
|
| 38 |
+
|
| 39 |
+
## Інсталяція та запуск
|
| 40 |
+
|
| 41 |
+
1. **Встановлення залежностей**:
|
| 42 |
```bash
|
| 43 |
pip install -r requirements.txt
|
| 44 |
```
|
| 45 |
|
| 46 |
+
2. **Налаштування середовища**:
|
| 47 |
+
Створіть файл `.env` та додайте ваш ключ:
|
| 48 |
+
```env
|
| 49 |
+
GROQ_API_KEY=gsk_...
|
|
|
|
| 50 |
```
|
| 51 |
|
| 52 |
+
3. **Запуск**:
|
| 53 |
+
```bash
|
| 54 |
+
python app.py
|
| 55 |
+
```
|
| 56 |
+
Інтерфейс буде доступний за адресою `http://localhost:7860`.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
+
## Структура проєкту
|
| 59 |
```
|
| 60 |
+
├── app.py # Точка входу Gradio UI
|
| 61 |
+
├── config.py # Конфігурація та системні промпти
|
| 62 |
+
├── requirements.txt # Залежності
|
| 63 |
+
├── data/ # Дані (парсений JSON, ембедінги)
|
|
|
|
| 64 |
├── modules/
|
| 65 |
+
│ ├── rag_system.py # Оркестратор пайплайну
|
| 66 |
+
│ ├── retriever.py # Гібридний пошук з фільтрацією
|
| 67 |
+
│ ├── reranker.py # Cross-Encoder ранжування
|
| 68 |
+
│ └── llm_handler.py # Інтеграція з LLM (LiteLLM)
|
| 69 |
└── scripts/
|
| 70 |
+
└── parser.py # Парсер та передобробка документів
|
| 71 |
```
|
| 72 |
|
| 73 |
+
---
|
| 74 |
+
**Виконав**: Dmytro
|
|
|
|
|
|
app.py
CHANGED
|
@@ -34,7 +34,7 @@ def format_sources(sources):
|
|
| 34 |
html += "</div>"
|
| 35 |
return html
|
| 36 |
|
| 37 |
-
def run_chat(query, api_key, search_method, use_reranker, top_k, temperature):
|
| 38 |
if not query.strip():
|
| 39 |
return "Будь ласка, введіть запитання.", ""
|
| 40 |
|
|
@@ -45,7 +45,21 @@ def run_chat(query, api_key, search_method, use_reranker, top_k, temperature):
|
|
| 45 |
"🧠 Семантичний": "semantic"
|
| 46 |
}
|
| 47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
internal_method = method_map.get(search_method, "hybrid")
|
|
|
|
| 49 |
|
| 50 |
try:
|
| 51 |
answer, sources = rag_system.process_query(
|
|
@@ -54,7 +68,8 @@ def run_chat(query, api_key, search_method, use_reranker, top_k, temperature):
|
|
| 54 |
use_reranker=use_reranker,
|
| 55 |
top_k_rerank=int(top_k),
|
| 56 |
temperature=float(temperature),
|
| 57 |
-
search_method=internal_method
|
|
|
|
| 58 |
)
|
| 59 |
|
| 60 |
sources_html = format_sources(sources)
|
|
@@ -63,7 +78,11 @@ def run_chat(query, api_key, search_method, use_reranker, top_k, temperature):
|
|
| 63 |
return f"Помилка при обробці запиту: {str(e)}", ""
|
| 64 |
|
| 65 |
# --- Gradio UI Construction ---
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
# Header
|
| 68 |
with gr.Row(elem_classes="header-container"):
|
| 69 |
with gr.Column(scale=0, min_width=80):
|
|
@@ -97,8 +116,25 @@ with gr.Blocks(title="Асистент із Законодавства") as demo
|
|
| 97 |
value="🔄 Гібридний (Рекомендовано)"
|
| 98 |
)
|
| 99 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
with gr.Accordion("🛠️ Розширені параметри", open=False):
|
| 101 |
-
use_reranker = gr.Checkbox(label="Використовувати Reranker", value=True)
|
| 102 |
top_k = gr.Slider(label="Кількість джерел", minimum=1, maximum=20, step=1, value=config.DEFAULT_TOP_K_RERANK)
|
| 103 |
temperature = gr.Slider(label="Температура генерації", minimum=0.0, maximum=1.0, step=0.1, value=0.5)
|
| 104 |
|
|
@@ -129,17 +165,16 @@ with gr.Blocks(title="Асистент із Законодавства") as demo
|
|
| 129 |
btn = gr.Button(q, elem_classes="example-btn")
|
| 130 |
btn.click(lambda x=q: x, outputs=[query_input]).then(
|
| 131 |
fn=run_chat,
|
| 132 |
-
inputs=[query_input, api_key_input, search_method, use_reranker, top_k, temperature],
|
| 133 |
outputs=[output_answer, output_sources]
|
| 134 |
)
|
| 135 |
|
| 136 |
# --- Interactions ---
|
| 137 |
submit_btn.click(
|
| 138 |
fn=run_chat,
|
| 139 |
-
inputs=[query_input, api_key_input, search_method, use_reranker, top_k, temperature],
|
| 140 |
outputs=[output_answer, output_sources]
|
| 141 |
)
|
| 142 |
|
| 143 |
if __name__ == "__main__":
|
| 144 |
-
|
| 145 |
-
demo.launch(server_name="0.0.0.0", server_port=7860, css=css_path)
|
|
|
|
| 34 |
html += "</div>"
|
| 35 |
return html
|
| 36 |
|
| 37 |
+
def run_chat(query, api_key, search_method, use_reranker, legal_area, top_k, temperature):
|
| 38 |
if not query.strip():
|
| 39 |
return "Будь ласка, введіть запитання.", ""
|
| 40 |
|
|
|
|
| 45 |
"🧠 Семантичний": "semantic"
|
| 46 |
}
|
| 47 |
|
| 48 |
+
# Mapping Ukrainian legal area names to internal keys
|
| 49 |
+
area_map = {
|
| 50 |
+
"Всі": "Всі",
|
| 51 |
+
"Сімейне право": "сімейне_право",
|
| 52 |
+
"Трудове право": "трудове_право",
|
| 53 |
+
"Земельне право": "земельне_право",
|
| 54 |
+
"Цивільне право": "цивільне_право",
|
| 55 |
+
"Податкове право": "податкове_право",
|
| 56 |
+
"Кримінальне право": "кримінальне_право",
|
| 57 |
+
"Конституційне право": "конституційне_право",
|
| 58 |
+
"Адміністративне судочинство": "адміністративне_судочинство"
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
internal_method = method_map.get(search_method, "hybrid")
|
| 62 |
+
internal_area = area_map.get(legal_area, "Всі")
|
| 63 |
|
| 64 |
try:
|
| 65 |
answer, sources = rag_system.process_query(
|
|
|
|
| 68 |
use_reranker=use_reranker,
|
| 69 |
top_k_rerank=int(top_k),
|
| 70 |
temperature=float(temperature),
|
| 71 |
+
search_method=internal_method,
|
| 72 |
+
legal_area=internal_area
|
| 73 |
)
|
| 74 |
|
| 75 |
sources_html = format_sources(sources)
|
|
|
|
| 78 |
return f"Помилка при обробці запиту: {str(e)}", ""
|
| 79 |
|
| 80 |
# --- Gradio UI Construction ---
|
| 81 |
+
css_path = Path("assets/style.css")
|
| 82 |
+
with open(css_path, "r", encoding="utf-8") as f:
|
| 83 |
+
custom_css = f.read()
|
| 84 |
+
|
| 85 |
+
with gr.Blocks(title="Асистент із Законодавства", css=custom_css, theme=gr.themes.Soft()) as demo:
|
| 86 |
# Header
|
| 87 |
with gr.Row(elem_classes="header-container"):
|
| 88 |
with gr.Column(scale=0, min_width=80):
|
|
|
|
| 116 |
value="🔄 Гібридний (Рекомендовано)"
|
| 117 |
)
|
| 118 |
|
| 119 |
+
use_reranker = gr.Checkbox(label="Використовувати Reranker", value=True)
|
| 120 |
+
|
| 121 |
+
legal_area = gr.Dropdown(
|
| 122 |
+
label="Галузь права (Фільтр)",
|
| 123 |
+
choices=[
|
| 124 |
+
"Всі",
|
| 125 |
+
"Сімейне право",
|
| 126 |
+
"Трудове право",
|
| 127 |
+
"Земельне право",
|
| 128 |
+
"Цивільне право",
|
| 129 |
+
"Податкове право",
|
| 130 |
+
"Кримінальне право",
|
| 131 |
+
"Конституційне право",
|
| 132 |
+
"Адміністративне судочинство"
|
| 133 |
+
],
|
| 134 |
+
value="Всі"
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
with gr.Accordion("🛠️ Розширені параметри", open=False):
|
|
|
|
| 138 |
top_k = gr.Slider(label="Кількість джерел", minimum=1, maximum=20, step=1, value=config.DEFAULT_TOP_K_RERANK)
|
| 139 |
temperature = gr.Slider(label="Температура генерації", minimum=0.0, maximum=1.0, step=0.1, value=0.5)
|
| 140 |
|
|
|
|
| 165 |
btn = gr.Button(q, elem_classes="example-btn")
|
| 166 |
btn.click(lambda x=q: x, outputs=[query_input]).then(
|
| 167 |
fn=run_chat,
|
| 168 |
+
inputs=[query_input, api_key_input, search_method, use_reranker, legal_area, top_k, temperature],
|
| 169 |
outputs=[output_answer, output_sources]
|
| 170 |
)
|
| 171 |
|
| 172 |
# --- Interactions ---
|
| 173 |
submit_btn.click(
|
| 174 |
fn=run_chat,
|
| 175 |
+
inputs=[query_input, api_key_input, search_method, use_reranker, legal_area, top_k, temperature],
|
| 176 |
outputs=[output_answer, output_sources]
|
| 177 |
)
|
| 178 |
|
| 179 |
if __name__ == "__main__":
|
| 180 |
+
demo.launch(server_name="0.0.0.0", server_port=7860)
|
|
|
assets/style.css
CHANGED
|
@@ -69,18 +69,36 @@ body, .gradio-container {
|
|
| 69 |
}
|
| 70 |
|
| 71 |
/* Input Styling */
|
| 72 |
-
.gr-textbox textarea, .gr-textbox input
|
|
|
|
| 73 |
background-color: rgba(255, 255, 255, 0.05) !important;
|
| 74 |
border: 1px solid rgba(255, 255, 255, 0.1) !important;
|
| 75 |
color: white !important;
|
| 76 |
border-radius: 10px !important;
|
| 77 |
}
|
| 78 |
|
| 79 |
-
.gr-textbox textarea:focus, .gr-textbox input:focus
|
|
|
|
|
|
|
| 80 |
border-color: var(--blue) !important;
|
| 81 |
box-shadow: 0 0 0 2px rgba(62, 139, 247, 0.2) !important;
|
| 82 |
}
|
| 83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
/* Examples Section */
|
| 85 |
.example-btn {
|
| 86 |
background: rgba(255, 255, 255, 0.05) !important;
|
|
|
|
| 69 |
}
|
| 70 |
|
| 71 |
/* Input Styling */
|
| 72 |
+
.gr-textbox textarea, .gr-textbox input,
|
| 73 |
+
input[type="text"], input[type="password"] {
|
| 74 |
background-color: rgba(255, 255, 255, 0.05) !important;
|
| 75 |
border: 1px solid rgba(255, 255, 255, 0.1) !important;
|
| 76 |
color: white !important;
|
| 77 |
border-radius: 10px !important;
|
| 78 |
}
|
| 79 |
|
| 80 |
+
.gr-textbox textarea:focus, .gr-textbox input:focus,
|
| 81 |
+
input[type="text"]:focus, input[type="password"]:focus {
|
| 82 |
+
background-color: rgba(255, 255, 255, 0.08) !important;
|
| 83 |
border-color: var(--blue) !important;
|
| 84 |
box-shadow: 0 0 0 2px rgba(62, 139, 247, 0.2) !important;
|
| 85 |
}
|
| 86 |
|
| 87 |
+
/* Autofill styling fix */
|
| 88 |
+
input:-webkit-autofill,
|
| 89 |
+
input:-webkit-autofill:hover,
|
| 90 |
+
input:-webkit-autofill:focus,
|
| 91 |
+
input:-webkit-autofill:active {
|
| 92 |
+
-webkit-box-shadow: 0 0 0 30px var(--bg-secondary) inset !important;
|
| 93 |
+
-webkit-text-fill-color: white !important;
|
| 94 |
+
transition: background-color 5000s ease-in-out 0s;
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
/* Ensure password and text inputs are readable always */
|
| 98 |
+
input[type="password"], input[type="text"], .gr-textbox input {
|
| 99 |
+
color: white !important;
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
/* Examples Section */
|
| 103 |
.example-btn {
|
| 104 |
background: rgba(255, 255, 255, 0.05) !important;
|
config.py
CHANGED
|
@@ -30,21 +30,22 @@ HYBRID_ALPHA = 0.3 # Semantic weight (higher = more semantic focus)
|
|
| 30 |
MIN_BM25_SCORE = 0.05 # Lower threshold to let good semantic hits through
|
| 31 |
|
| 32 |
# System Prompts
|
| 33 |
-
SYSTEM_PROMPT = """
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
1.
|
| 37 |
-
2.
|
| 38 |
-
3.
|
| 39 |
-
4.
|
| 40 |
-
5.
|
| 41 |
-
6.
|
| 42 |
-
7.
|
| 43 |
-
8.
|
| 44 |
-
9.
|
| 45 |
-
|
|
|
|
| 46 |
|
| 47 |
Context: {context}
|
| 48 |
|
| 49 |
-
|
| 50 |
"""
|
|
|
|
| 30 |
MIN_BM25_SCORE = 0.05 # Lower threshold to let good semantic hits through
|
| 31 |
|
| 32 |
# System Prompts
|
| 33 |
+
SYSTEM_PROMPT = """Ви — професійний юридичний асистент, що спеціалізується на законодавстві України. Ваше завдання — надати максимально корисну відповідь на основі наданих фрагментів документів (Контекст).
|
| 34 |
+
|
| 35 |
+
ОБОВ'ЯЗКОВІ ПРАВИЛА:
|
| 36 |
+
1. ЗАВЖДИ ретельно аналізуйте весь наданий контекст.
|
| 37 |
+
2. Якщо в контексті є пряма відповідь — надайте її чітко та структуровано.
|
| 38 |
+
3. Якщо інформація часткова — поясніть, що саме відомо з документів.
|
| 39 |
+
4. ЗАВЖДИ вказуйте джерела: вставляйте номери посилань у квадратних дужках [1], [2], [3] безпосередньо після тверджень, які вони підтверджують.
|
| 40 |
+
5. Якщо в контексті НЕМАЄ інформації, скажіть: "На жаль, у наданих документах немає інформації для відповіді на це запитання."
|
| 41 |
+
6. Відповідайте ТІЛЬКИ українською мовою.
|
| 42 |
+
7. НІКОЛИ не вигадуйте статті або факти, яких немає в контексті.
|
| 43 |
+
8. При цитуванні норм вказуйте номери статей та назви кодексів/законів.
|
| 44 |
+
9. Використовуйте списки та заголовки для кращої структури.
|
| 45 |
+
|
| 46 |
+
Приклад цитування: Згідно зі статтею 115 ККУ, вбивство — це умисне протиправне заподіяння смерті іншій людині [1]. За це передбачено покарання у вигляді позбавлення волі на строк від семи до п'ятнадцяти років [2].
|
| 47 |
|
| 48 |
Context: {context}
|
| 49 |
|
| 50 |
+
Пам'ятайте: відповідайте ТІЛЬКИ українською, будьте точними, цитуйте джерела для кожного твердження.
|
| 51 |
"""
|
modules/rag_system.py
CHANGED
|
@@ -27,7 +27,8 @@ class RAGSystem:
|
|
| 27 |
top_k_retrieval: int = config.DEFAULT_TOP_K_RETRIEVAL,
|
| 28 |
top_k_rerank: int = config.DEFAULT_TOP_K_RERANK,
|
| 29 |
temperature: float = config.DEFAULT_TEMPERATURE,
|
| 30 |
-
search_method: str = 'hybrid'
|
|
|
|
| 31 |
) -> Tuple[str, List[Dict]]:
|
| 32 |
"""
|
| 33 |
Main RAG pipeline:
|
|
@@ -43,8 +44,8 @@ class RAGSystem:
|
|
| 43 |
return "Будь ласка, введіть API ключ (Groq) для продовження.", []
|
| 44 |
|
| 45 |
# 1. Retrieval
|
| 46 |
-
print(f"Retrieving for: {query} (method: {search_method})")
|
| 47 |
-
retrieved_chunks = self.retriever.search(query, top_k=top_k_retrieval, method=search_method)
|
| 48 |
|
| 49 |
# 2. Reranking
|
| 50 |
if use_reranker and retrieved_chunks:
|
|
|
|
| 27 |
top_k_retrieval: int = config.DEFAULT_TOP_K_RETRIEVAL,
|
| 28 |
top_k_rerank: int = config.DEFAULT_TOP_K_RERANK,
|
| 29 |
temperature: float = config.DEFAULT_TEMPERATURE,
|
| 30 |
+
search_method: str = 'hybrid',
|
| 31 |
+
legal_area: str = None
|
| 32 |
) -> Tuple[str, List[Dict]]:
|
| 33 |
"""
|
| 34 |
Main RAG pipeline:
|
|
|
|
| 44 |
return "Будь ласка, введіть API ключ (Groq) для продовження.", []
|
| 45 |
|
| 46 |
# 1. Retrieval
|
| 47 |
+
print(f"Retrieving for: {query} (method: {search_method}, legal_area: {legal_area})")
|
| 48 |
+
retrieved_chunks = self.retriever.search(query, top_k=top_k_retrieval, method=search_method, legal_area=legal_area)
|
| 49 |
|
| 50 |
# 2. Reranking
|
| 51 |
if use_reranker and retrieved_chunks:
|
modules/retriever.py
CHANGED
|
@@ -285,7 +285,8 @@ class Retriever:
|
|
| 285 |
query: str,
|
| 286 |
top_k: int = 30,
|
| 287 |
method: str = 'hybrid',
|
| 288 |
-
alpha: float = None # Uses config.HYBRID_ALPHA if None
|
|
|
|
| 289 |
) -> List[Dict]:
|
| 290 |
"""
|
| 291 |
Search for relevant chunks using specified method.
|
|
@@ -295,6 +296,7 @@ class Retriever:
|
|
| 295 |
top_k: Number of results to return
|
| 296 |
method: 'bm25', 'semantic', or 'hybrid'
|
| 297 |
alpha: Weight for hybrid search (semantic weight, 0-1)
|
|
|
|
| 298 |
|
| 299 |
Returns:
|
| 300 |
List of result dictionaries with chunk and score
|
|
@@ -302,6 +304,16 @@ class Retriever:
|
|
| 302 |
# Ensure top_k is an integer (handle Gradio sliders passing floats)
|
| 303 |
top_k = int(top_k)
|
| 304 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 305 |
method_map = {
|
| 306 |
'bm25': self._search_bm25,
|
| 307 |
'semantic': self._search_semantic,
|
|
@@ -314,27 +326,50 @@ class Retriever:
|
|
| 314 |
search_func = method_map[method]
|
| 315 |
if method == 'hybrid':
|
| 316 |
effective_alpha = alpha if alpha is not None else config.HYBRID_ALPHA
|
| 317 |
-
return search_func(query, top_k, effective_alpha)
|
| 318 |
-
return search_func(query, top_k)
|
| 319 |
|
| 320 |
-
def _search_bm25(self, query: str, top_k: int) -> List[Dict]:
|
| 321 |
"""Keyword-based BM25 search with spell correction."""
|
| 322 |
tokenized_query = self._tokenize_and_lemmatize(query)
|
| 323 |
scores = self.bm25.get_scores(tokenized_query)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 324 |
top_indices = np.argsort(scores)[::-1][:top_k]
|
| 325 |
|
|
|
|
| 326 |
return [
|
| 327 |
{
|
| 328 |
'chunk': self.chunks[idx],
|
| 329 |
'score': float(scores[idx]),
|
| 330 |
'method': 'bm25'
|
| 331 |
}
|
| 332 |
-
for idx in top_indices
|
| 333 |
]
|
| 334 |
|
| 335 |
-
def _search_semantic(self, query: str, top_k: int) -> List[Dict]:
|
| 336 |
"""Semantic similarity search using embeddings."""
|
| 337 |
query_embedding = self.model.encode(query, convert_to_tensor=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 338 |
hits = util.semantic_search(query_embedding, self.embeddings, top_k=top_k)[0]
|
| 339 |
|
| 340 |
return [
|
|
@@ -346,7 +381,7 @@ class Retriever:
|
|
| 346 |
for hit in hits
|
| 347 |
]
|
| 348 |
|
| 349 |
-
def _search_hybrid(self, query: str, top_k: int, alpha: float) -> List[Dict]:
|
| 350 |
"""
|
| 351 |
Hybrid search combining BM25 and semantic similarity.
|
| 352 |
|
|
@@ -360,6 +395,13 @@ class Retriever:
|
|
| 360 |
query_embedding = self.model.encode(query, convert_to_tensor=True)
|
| 361 |
semantic_scores = util.cos_sim(query_embedding, self.embeddings)[0].cpu().numpy()
|
| 362 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 363 |
# Normalize scores to [0, 1]
|
| 364 |
bm25_norm = self._min_max_normalize(bm25_scores)
|
| 365 |
semantic_norm = self._min_max_normalize(semantic_scores)
|
|
@@ -367,6 +409,12 @@ class Retriever:
|
|
| 367 |
# Combine with weighted sum
|
| 368 |
combined_scores = alpha * semantic_norm + (1 - alpha) * bm25_norm
|
| 369 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 370 |
# Apply BM25 threshold: penalize chunks with no keyword overlap
|
| 371 |
bm25_max = np.max(bm25_scores) if np.max(bm25_scores) > 0 else 1.0
|
| 372 |
bm25_relative = bm25_scores / bm25_max
|
|
@@ -384,7 +432,7 @@ class Retriever:
|
|
| 384 |
'semantic_score': float(semantic_scores[idx]),
|
| 385 |
'method': 'hybrid'
|
| 386 |
}
|
| 387 |
-
for idx in top_indices
|
| 388 |
]
|
| 389 |
|
| 390 |
@staticmethod
|
|
|
|
| 285 |
query: str,
|
| 286 |
top_k: int = 30,
|
| 287 |
method: str = 'hybrid',
|
| 288 |
+
alpha: float = None, # Uses config.HYBRID_ALPHA if None
|
| 289 |
+
legal_area: str = None
|
| 290 |
) -> List[Dict]:
|
| 291 |
"""
|
| 292 |
Search for relevant chunks using specified method.
|
|
|
|
| 296 |
top_k: Number of results to return
|
| 297 |
method: 'bm25', 'semantic', or 'hybrid'
|
| 298 |
alpha: Weight for hybrid search (semantic weight, 0-1)
|
| 299 |
+
legal_area: Optional filter by metadata['legal_area']
|
| 300 |
|
| 301 |
Returns:
|
| 302 |
List of result dictionaries with chunk and score
|
|
|
|
| 304 |
# Ensure top_k is an integer (handle Gradio sliders passing floats)
|
| 305 |
top_k = int(top_k)
|
| 306 |
|
| 307 |
+
# Pre-filter chunks by legal_area if provided
|
| 308 |
+
filtered_indices = None
|
| 309 |
+
if legal_area and legal_area != "Всі":
|
| 310 |
+
filtered_indices = [
|
| 311 |
+
i for i, chunk in enumerate(self.chunks)
|
| 312 |
+
if chunk.get('metadata', {}).get('legal_area') == legal_area
|
| 313 |
+
]
|
| 314 |
+
if not filtered_indices:
|
| 315 |
+
return []
|
| 316 |
+
|
| 317 |
method_map = {
|
| 318 |
'bm25': self._search_bm25,
|
| 319 |
'semantic': self._search_semantic,
|
|
|
|
| 326 |
search_func = method_map[method]
|
| 327 |
if method == 'hybrid':
|
| 328 |
effective_alpha = alpha if alpha is not None else config.HYBRID_ALPHA
|
| 329 |
+
return search_func(query, top_k, effective_alpha, filtered_indices=filtered_indices)
|
| 330 |
+
return search_func(query, top_k, filtered_indices=filtered_indices)
|
| 331 |
|
| 332 |
+
def _search_bm25(self, query: str, top_k: int, filtered_indices: List[int] = None) -> List[Dict]:
|
| 333 |
"""Keyword-based BM25 search with spell correction."""
|
| 334 |
tokenized_query = self._tokenize_and_lemmatize(query)
|
| 335 |
scores = self.bm25.get_scores(tokenized_query)
|
| 336 |
+
|
| 337 |
+
if filtered_indices is not None:
|
| 338 |
+
# Mask scores for non-filtered chunks
|
| 339 |
+
mask = np.zeros(len(scores), dtype=bool)
|
| 340 |
+
mask[filtered_indices] = True
|
| 341 |
+
scores[~mask] = -1e9
|
| 342 |
+
|
| 343 |
top_indices = np.argsort(scores)[::-1][:top_k]
|
| 344 |
|
| 345 |
+
# Filter out masked scores from results
|
| 346 |
return [
|
| 347 |
{
|
| 348 |
'chunk': self.chunks[idx],
|
| 349 |
'score': float(scores[idx]),
|
| 350 |
'method': 'bm25'
|
| 351 |
}
|
| 352 |
+
for idx in top_indices if scores[idx] > -1e8
|
| 353 |
]
|
| 354 |
|
| 355 |
+
def _search_semantic(self, query: str, top_k: int, filtered_indices: List[int] = None) -> List[Dict]:
|
| 356 |
"""Semantic similarity search using embeddings."""
|
| 357 |
query_embedding = self.model.encode(query, convert_to_tensor=True)
|
| 358 |
+
|
| 359 |
+
if filtered_indices is not None:
|
| 360 |
+
# Filter embeddings before searching
|
| 361 |
+
filtered_embeddings = self.embeddings[filtered_indices]
|
| 362 |
+
hits = util.semantic_search(query_embedding, filtered_embeddings, top_k=top_k)[0]
|
| 363 |
+
|
| 364 |
+
return [
|
| 365 |
+
{
|
| 366 |
+
'chunk': self.chunks[filtered_indices[hit['corpus_id']]],
|
| 367 |
+
'score': float(hit['score']),
|
| 368 |
+
'method': 'semantic'
|
| 369 |
+
}
|
| 370 |
+
for hit in hits
|
| 371 |
+
]
|
| 372 |
+
|
| 373 |
hits = util.semantic_search(query_embedding, self.embeddings, top_k=top_k)[0]
|
| 374 |
|
| 375 |
return [
|
|
|
|
| 381 |
for hit in hits
|
| 382 |
]
|
| 383 |
|
| 384 |
+
def _search_hybrid(self, query: str, top_k: int, alpha: float, filtered_indices: List[int] = None) -> List[Dict]:
|
| 385 |
"""
|
| 386 |
Hybrid search combining BM25 and semantic similarity.
|
| 387 |
|
|
|
|
| 395 |
query_embedding = self.model.encode(query, convert_to_tensor=True)
|
| 396 |
semantic_scores = util.cos_sim(query_embedding, self.embeddings)[0].cpu().numpy()
|
| 397 |
|
| 398 |
+
# Mask scores if filtered_indices is provided
|
| 399 |
+
if filtered_indices is not None:
|
| 400 |
+
mask = np.zeros(len(self.chunks), dtype=bool)
|
| 401 |
+
mask[filtered_indices] = True
|
| 402 |
+
bm25_scores[~mask] = 0.0 # BM25 min is typically 0
|
| 403 |
+
semantic_scores[~mask] = -1.0 # Cosine min is -1
|
| 404 |
+
|
| 405 |
# Normalize scores to [0, 1]
|
| 406 |
bm25_norm = self._min_max_normalize(bm25_scores)
|
| 407 |
semantic_norm = self._min_max_normalize(semantic_scores)
|
|
|
|
| 409 |
# Combine with weighted sum
|
| 410 |
combined_scores = alpha * semantic_norm + (1 - alpha) * bm25_norm
|
| 411 |
|
| 412 |
+
# Re-apply mask after normalization just in case
|
| 413 |
+
if filtered_indices is not None:
|
| 414 |
+
mask = np.zeros(len(self.chunks), dtype=bool)
|
| 415 |
+
mask[filtered_indices] = True
|
| 416 |
+
combined_scores[~mask] = -1.0
|
| 417 |
+
|
| 418 |
# Apply BM25 threshold: penalize chunks with no keyword overlap
|
| 419 |
bm25_max = np.max(bm25_scores) if np.max(bm25_scores) > 0 else 1.0
|
| 420 |
bm25_relative = bm25_scores / bm25_max
|
|
|
|
| 432 |
'semantic_score': float(semantic_scores[idx]),
|
| 433 |
'method': 'hybrid'
|
| 434 |
}
|
| 435 |
+
for idx in top_indices if combined_scores[idx] > -0.9
|
| 436 |
]
|
| 437 |
|
| 438 |
@staticmethod
|