from __future__ import annotations import os import json import logging import time from models import OptimizeRequest, QARequest, AutotuneRequest from fastapi import FastAPI, HTTPException, UploadFile, File, Form from fastapi.middleware.cors import CORSMiddleware import uvicorn import shutil try: from ragmint.autotuner import AutoRAGTuner from ragmint.qa_generator import generate_validation_qa from ragmint.explainer import explain_results from ragmint.leaderboard import Leaderboard from ragmint.tuner import RAGMint except Exception as e: AutoRAGTuner = None generate_validation_qa = None explain_results = None Leaderboard = None RAGMint = None _import_error = e else: _import_error = None from dotenv import load_dotenv load_dotenv() # Logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger("ragmint_mcp_server") # FastAPI app = FastAPI(title="Ragmint MCP Server", version="0.1.0") app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) DEFAULT_DATA_DIR = "../data/docs" LEADERBOARD_STORAGE = "experiments/leaderboard.jsonl" os.makedirs("../experiments", exist_ok=True) @app.get("/health") def health(): return { "status": "ok", "ragmint_imported": _import_error is None, "import_error": str(_import_error) if _import_error else None, } @app.post("/upload_docs") async def upload_docs( docs_path: str = Form(...), # User specifies folder files: list[UploadFile] = File(...) ): os.makedirs(docs_path, exist_ok=True) # create folder if missing saved_files = [] for file in files: file_path = os.path.join(docs_path, file.filename) with open(file_path, "wb") as f: shutil.copyfileobj(file.file, f) saved_files.append(file.filename) return {"status": "ok", "uploaded_files": saved_files, "docs_path": docs_path} @app.post("/optimize_rag") def optimize_rag(req: OptimizeRequest): logger.info("Received optimize_rag request: %s", req.json()) if RAGMint is None: raise HTTPException( status_code=500, detail=f"Ragmint imports failed or RAGMint unavailable: {_import_error}" ) docs_path = req.docs_path or DEFAULT_DATA_DIR if not os.path.isdir(docs_path): raise HTTPException(status_code=400, detail=f"docs_path does not exist: {docs_path}") try: # Build RAGMint exactly from request rag = RAGMint( docs_path=docs_path, retrievers=req.retriever, embeddings=req.embedding_model, rerankers=(req.rerankers or ["mmr"]), chunk_sizes=req.chunk_sizes, overlaps=req.overlaps, strategies=req.strategy, ) # Validation selection validation_set = None validation_choice = (req.validation_choice or "").strip() default_val_path = os.path.join(docs_path, "validation_qa.json") # Auto if not validation_choice: if os.path.exists(default_val_path): validation_set = default_val_path logger.info("Using default validation set: %s", validation_set) else: logger.warning("No validation_choice provided and no default found.") validation_set = None # Remote HF dataset elif "/" in validation_choice and not os.path.exists(validation_choice): validation_set = validation_choice logger.info("Using Hugging Face validation dataset: %s", validation_set) # Local file elif os.path.exists(validation_choice): validation_set = validation_choice logger.info("Using local validation dataset: %s", validation_set) # Generate elif validation_choice.lower() == "generate": try: gen_path = os.path.join(docs_path, "validation_qa.json") generate_validation_qa( docs_path=docs_path, output_path=gen_path, llm_model=req.llm_model if hasattr(req, "llm_model") else "gemini-2.5-flash-lite" ) validation_set = gen_path logger.info("Generated new validation QA set at: %s", validation_set) except Exception as e: logger.exception("Failed to generate validation QA dataset: %s", e) raise HTTPException(status_code=500, detail=f"Failed to generate validation QA dataset: {e}") # Optimize start_time = time.time() best, results = rag.optimize( validation_set=validation_set, metric=req.metric, trials=req.trials, search_type=req.search_type ) elapsed = time.time() - start_time run_id = f"opt_{int(time.time())}" # Corpus stats try: corpus_stats = { "num_docs": len(rag.documents), "avg_len": sum(len(d.split()) for d in rag.documents) / max(1, len(rag.documents)), "corpus_size": sum(len(d) for d in rag.documents), } except Exception: corpus_stats = None # Leaderboard try: if Leaderboard: lb = Leaderboard() lb.upload( run_id=run_id, best_config=best, best_score=best.get("faithfulness", best.get("score", 0.0)), all_results=results, documents=os.listdir(docs_path), model=best.get("embedding_model", req.embedding_model), corpus_stats=corpus_stats, ) except Exception: logger.exception("Leaderboard persistence failed for optimize_rag") return { "status": "finished", "run_id": run_id, "elapsed_seconds": elapsed, "best_config": best, "results": results, "corpus_stats": corpus_stats, } except Exception as exc: logger.exception("optimize_rag failed") raise HTTPException(status_code=500, detail=str(exc)) @app.post("/autotune_rag") def autotune_rag(req: AutotuneRequest): logger.info("Received autotune_rag request: %s", req.json()) if AutoRAGTuner is None or RAGMint is None: raise HTTPException( status_code=500, detail=f"Ragmint autotuner/RAGMint imports failed: {_import_error}" ) docs_path = req.docs_path or DEFAULT_DATA_DIR if not os.path.isdir(docs_path): raise HTTPException(status_code=400, detail=f"docs_path does not exist: {docs_path}") try: start_time = time.time() tuner = AutoRAGTuner(docs_path=docs_path) rec = tuner.recommend( embedding_model=req.embedding_model, num_chunk_pairs=req.num_chunk_pairs ) chunk_candidates = tuner.suggest_chunk_sizes( model_name=rec.get("embedding_model"), num_pairs=int(req.num_chunk_pairs), step=20 ) chunk_sizes = sorted({c for c, _ in chunk_candidates}) overlaps = sorted({o for _, o in chunk_candidates}) rag = RAGMint( docs_path=docs_path, retrievers=[rec["retriever"]], embeddings=[rec["embedding_model"]], rerankers=["mmr"], chunk_sizes=chunk_sizes, overlaps=overlaps, strategies=[rec["strategy"]], ) # Validation selection validation_set = None validation_choice = (req.validation_choice or "").strip() default_val_path = os.path.join(docs_path, "validation_qa.jsonl") if not validation_choice: if os.path.exists(default_val_path): validation_set = default_val_path logger.info("Using default validation set: %s", validation_set) else: logger.warning("No validation_choice provided and no default found.") validation_set = None elif "/" in validation_choice and not os.path.exists(validation_choice): validation_set = validation_choice elif os.path.exists(validation_choice): validation_set = validation_choice elif validation_choice.lower() == "generate": try: gen_path = os.path.join(docs_path, "validation_qa.json") generate_validation_qa( docs_path=docs_path, output_path=gen_path, llm_model=req.llm_model if hasattr(req, "llm_model") else "gemini-2.5-flash-lite", ) validation_set = gen_path except Exception as e: logger.exception("Failed to generate validation QA dataset: %s", e) raise HTTPException(status_code=500, detail=f"Failed to generate validation QA dataset: {e}") # Full optimize best, results = rag.optimize( validation_set=validation_set, metric=req.metric, search_type=req.search_type, trials=req.trials, ) elapsed = time.time() - start_time run_id = f"autotune_{int(time.time())}" # Corpus stats try: corpus_stats = { "num_docs": len(rag.documents), "avg_len": sum(len(d.split()) for d in rag.documents) / max(1, len(rag.documents)), "corpus_size": sum(len(d) for d in rag.documents), } except Exception: corpus_stats = None # Leaderboard try: if Leaderboard: lb = Leaderboard() lb.upload( run_id=run_id, best_config=best, best_score=best.get("faithfulness", best.get("score", 0.0)), all_results=results, documents=os.listdir(docs_path), model=best.get("embedding_model", rec.get("embedding_model")), corpus_stats=corpus_stats, ) except Exception: logger.exception("Leaderboard persistence failed for autotune_rag") return { "status": "finished", "run_id": run_id, "elapsed_seconds": elapsed, "recommendation": rec, "chunk_candidates": chunk_candidates, "best_config": best, "results": results, "corpus_stats": corpus_stats, } except Exception as exc: logger.exception("autotune_rag failed") raise HTTPException(status_code=500, detail=str(exc)) @app.post("/generate_validation_qa") def generate_qa(req: QARequest): logger.info("Received generate_validation_qa request: %s", req.json()) if generate_validation_qa is None: raise HTTPException(status_code=500, detail=f"Ragmint imports failed: {_import_error}") try: out_path = f"data/docs/validation_qa.json" os.makedirs(os.path.dirname(out_path), exist_ok=True) generate_validation_qa( docs_path=req.docs_path, output_path=out_path, llm_model=req.llm_model, batch_size=req.batch_size, min_q=req.min_q, max_q=req.max_q, ) with open(out_path, "r", encoding="utf-8") as f: data = json.load(f) return { "status": "finished", "output_path": out_path, "preview_count": len(data), "sample": data[:5], } except Exception as exc: logger.exception("generate_validation_qa failed") raise HTTPException(status_code=500, detail=str(exc)) # ----------------------- # FastAPI launch # ----------------------- def main(): uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info") if __name__ == "__main__": main()