import fitz  # PyMuPDF
import re
import chromadb
from chromadb.utils import embedding_functions
import uuid
import torch
from langchain.text_splitter import SentenceTransformersTokenTextSplitter
from sentence_transformers import CrossEncoder
# Embedding model used for token-aware chunking and for the Chroma collection
emb_model_name = "sentence-transformers/all-mpnet-base-v2"
sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-mpnet-base-v2")
# Cross-encoder used to re-rank retrieved chunks
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
# Persistent Chroma vector store (cosine distance)
client = chromadb.PersistentClient(path='.vectorstore')
collection = client.get_or_create_collection(name='huerto', embedding_function=sentence_transformer_ef,
                                             metadata={"hnsw:space": "cosine"})
def parse_pdf(file):
    '''Convert a pdf into a list of cleaned page texts.'''
    pdf = fitz.open(file)
    output = []
    for page_num in range(pdf.page_count):
        page = pdf[page_num]
        text = page.get_text()
        # Merge hyphenated words
        text = re.sub(r"(\w+)-\n(\w+)", r"\1\2", text)
        # Fix newlines in the middle of sentences
        text = re.sub(r"(?<!\n\s)\n(?!\s\n)", " ", text.strip())
        # Remove multiple newlines
        text = re.sub(r"\n\s*\n", "\n\n", text)
        output.append(text)
    return output
def file_to_splits(file, tokens_per_chunk, chunk_overlap):
    '''Split a pdf into a list of [chunk_text, metadata, id] entries.'''
    text_splitter = SentenceTransformersTokenTextSplitter(
        model_name=emb_model_name,
        tokens_per_chunk=tokens_per_chunk,
        chunk_overlap=chunk_overlap,
    )
    text = parse_pdf(file)
    doc_chunks = []
    for i in range(len(text)):
        chunks = text_splitter.split_text(text[i])
        for j in range(len(chunks)):
            doc = [chunks[j],
                   {"source": file.split('/')[-1], "page": i + 1, "chunk": j + 1},
                   str(uuid.uuid4())]
            doc_chunks.append(doc)
    return doc_chunks
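# For reference, file_to_splits returns entries shaped like the sketch below
# (the values are illustrative placeholders, not real output):
# [['First chunk of page 1 ...',
#   {'source': 'manual_huerto.pdf', 'page': 1, 'chunk': 1},
#   'e7b1c2d4-...'],
#  ...]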
def file_to_vs(file, tokens_per_chunk, chunk_overlap):
    '''Chunk a file and add the chunks to the Chroma collection.'''
    try:
        splits = file_to_splits(file, tokens_per_chunk, chunk_overlap)
        # Transpose [[text, metadata, id], ...] into (texts, metadatas, ids)
        splits = list(zip(*splits))
        collection.add(documents=list(splits[0]), metadatas=list(splits[1]), ids=list(splits[2]))
        return 'Files uploaded successfully'
    except Exception as e:
        return str(e)
def similarity_search(query, k):
    '''Retrieve 20 candidate chunks from Chroma, re-rank with the cross-encoder and return the top k.'''
    sources = {}
    ss_out = collection.query(query_texts=[query], n_results=20)
    for idx in range(len(ss_out['ids'][0])):
        # Score the (query, chunk) pair with the cross-encoder; a sigmoid maps the logit to 0-1
        score = float(cross_encoder.predict([query, ss_out['documents'][0][idx]], activation_fct=torch.nn.Sigmoid()))
        sources[str(idx)] = {"page_content": ss_out['documents'][0][idx],
                             "metadata": ss_out['metadatas'][0][idx],
                             "similarity": round(score * 100, 2)}
    # Keep only the k best-scoring chunks
    sorted_sources = sorted(sources.items(), key=lambda x: x[1]['similarity'], reverse=True)
    sources = {}
    for idx in range(k):
        sources[str(idx)] = sorted_sources[idx][1]
    return sources
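# Usage example (a minimal sketch, not part of the original script): the file
# name 'manual_huerto.pdf', the chunking parameters and the query below are
# illustrative placeholders.
if __name__ == "__main__":
    status = file_to_vs('manual_huerto.pdf', tokens_per_chunk=256, chunk_overlap=32)
    print(status)
    results = similarity_search('¿Cuándo se siembran los tomates?', k=3)
    for rank, src in results.items():
        print(rank, src['similarity'], src['metadata'], src['page_content'][:80])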