import base64
import os
import time
from io import BytesIO
from typing import Optional

from dotenv import load_dotenv

# Load environment variables (e.g. API keys) from .env before importing the backend modules.
load_dotenv()

import gradio as gr
from PIL import Image
from olmocr.data.renderpdf import render_pdf_to_base64png

from common import MODELS_MAP, MODEL_GEMINI, MODEL_OLMOCR
from gemini_backend import _run_gemini_vision
from logging_helper import (
    log as _log,
    log_debug as _log_debug,
    get_latest_model_log as _get_latest_model_log,
)
from olm_ocr import _run_olmocr
from openai_backend import _run_openai_vision

APP_TITLE = "words2doc"
APP_DESCRIPTION = "Upload a PDF or image with (handwritten) text and convert it to CSV using different LLM backends."
# -------- Utility helpers -------- #

def _load_image_from_upload(path: str) -> Image.Image:
    """Load an image from a path (for image uploads)."""
    return Image.open(path).convert("RGB")


def _pdf_to_pil_image(path: str, page: int = 1, target_longest_image_dim: int = 1288) -> Image.Image:
    """Render a single PDF page to a PIL image via olmocr's helper."""
    image_base64 = render_pdf_to_base64png(path, page, target_longest_image_dim=target_longest_image_dim)
    return Image.open(BytesIO(base64.b64decode(image_base64)))


def _image_from_any_file(file_path: str) -> Image.Image:
    """Accept either a PDF or an image file and always return a PIL image (first page for PDFs)."""
    lower = file_path.lower()
    if lower.endswith(".pdf"):
        return _pdf_to_pil_image(file_path)
    return _load_image_from_upload(file_path)


def _convert_to_grayscale(image: Image.Image) -> Image.Image:
    """Convert an image to grayscale."""
    return image.convert("L")


def _downscale_image(image: Image.Image, target_width: int = 1024) -> Image.Image:
    """Downscale an image to the target width, preserving the aspect ratio."""
    if image.width > target_width:
        ratio = target_width / float(image.width)
        new_height = int(float(image.height) * ratio)
        return image.resize((target_width, new_height), Image.Resampling.LANCZOS)
    return image
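# Note on _downscale_image: a 3096 x 4128 portrait scan, for example, would come out as
# 1024 x 1365 (4128 * 1024 / 3096 ≈ 1365); images at or below 1024 px wide pass through unchanged.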
def _write_csv_to_temp_file(csv_text: str) -> str:
    """Write CSV text to a temporary file and return its path."""
    import tempfile

    fd, path = tempfile.mkstemp(suffix=".csv", prefix="words2doc_")
    with os.fdopen(fd, "w", encoding="utf-8") as f:
        f.write(csv_text)
    return path
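# Illustrative use of the helpers above ("scan.pdf" is a hypothetical file, not shipped with the app):
#
#   image = _image_from_any_file("scan.pdf")                 # first PDF page as a PIL image
#   image = _downscale_image(_convert_to_grayscale(image))   # same preprocessing process_document applies
#   csv_path = _write_csv_to_temp_file("hola,hello\n")       # returns a path ending in .csv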
# -------- Backends -------- #

def _encode_image(image_path):
    """Read an image file from disk and return its contents as a base64-encoded string."""
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")
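# The dispatch in process_document below only assumes that MODELS_MAP maps each model name to a
# dict with a "backend" key. A rough sketch of what common.py is expected to provide (the string
# keys here are placeholders, not the real model names):
#
#   MODELS_MAP = {
#       "some-openai-model": {"backend": "openai"},
#       MODEL_GEMINI: {"backend": "gemini"},
#       MODEL_OLMOCR: {"backend": "olmocr"},
#   }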
# -------- Main processing function -------- #

def process_document(file_obj, model_choice: str, prompt: str):
    """Run the selected backend on the uploaded file and return (csv_text, csv_file_path, latest_log)."""
    if file_obj is None:
        return "No file uploaded.", None, ""

    # gr.File may hand us a tempfile-like object or a plain path string.
    file_path = getattr(file_obj, "name", None) or file_obj
    image = _image_from_any_file(file_path)

    _log("Converting image to grayscale")
    image = _convert_to_grayscale(image)
    _log("Downscaling image to 1024 width")
    image = _downscale_image(image, 1024)

    if not prompt.strip():
        prompt = (
            "You are an OCR-to-CSV assistant. Read the table or structured text in the image and output a valid "
            "CSV representation. Use commas as separators and include a header row if appropriate."
        )

    _log_debug(f"Using model: {model_choice}")
    if MODELS_MAP[model_choice]["backend"] == "openai":
        csv_text = _run_openai_vision(image, prompt, model_choice)
    elif MODELS_MAP[model_choice]["backend"] == "gemini":
        csv_text = _run_gemini_vision(image, prompt, model_choice)
    elif MODELS_MAP[model_choice]["backend"] == "olmocr":
        csv_text = _run_olmocr(image, prompt)
    else:
        csv_text = f"Unknown model choice: {model_choice}"

    csv_file_path = _write_csv_to_temp_file(csv_text)
    latest_log = _get_latest_model_log() or ""
    return csv_text, csv_file_path, latest_log
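# Quick smoke test outside the UI (hypothetical path; assumes "samples/vocab.jpg" exists locally
# and that the MODEL_GEMINI backend has a valid API key configured via .env):
#
#   csv_text, csv_path, log_tail = process_document("samples/vocab.jpg", MODEL_GEMINI, "")
#   print(csv_text)
#   print("CSV written to:", csv_path)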
# -------- Gradio UI -------- #

def build_interface() -> gr.Blocks:
    with gr.Blocks(title=APP_TITLE) as demo:
        gr.Markdown(f"# {APP_TITLE}")
        gr.Markdown(APP_DESCRIPTION)

        with gr.Row():
            with gr.Column(scale=1):
                file_input = gr.File(
                    label="Upload PDF or image",
                    file_types=[".pdf", ".png", ".jpg", ".jpeg", ".webp"],
                )
                image_example_preview = gr.Image(
                    label="Example image preview",
                    value="static/vocab.jpg",  # adjust to real relative path
                    interactive=False,
                )
                gr.Examples(
                    examples=[["vocab.jpg", "vocab.jpg"]],
                    inputs=[file_input, image_example_preview],
                    label="Example image",
                )
                model_selector = gr.Dropdown(
                    label="LLM backend",
                    choices=list(MODELS_MAP.keys()),
                    value=MODEL_GEMINI,
                )
                prompt_editor = gr.Textbox(
                    label="Prompt editor",
                    value=(
                        "You are an OCR and vocabulary extractor.\n"
                        "You are given a photo of a vocabulary book page with words in the original language and their translations.\n"
                        "Your task:\n"
                        "- Read the text on the page.\n"
                        "- First, detect the language of the words.\n"
                        "- Identify all words and their corresponding translations.\n"
                        "- Do NOT include dates, page numbers, headings, or example sentences.\n"
                        "- Do NOT repeat the same word twice.\n"
                        "- If there are duplicates, keep only one row.\n"
                        "\n"
                        "Output format (VERY IMPORTANT):\n"
                        "- Output ONLY CSV rows.\n"
                        "- NO explanations, NO extra text, NO quotes.\n"
                        "- Each line must be: <word>,<translation>\n"
                        "- Use a comma as separator.\n"
                        "- No header row.\n"
                        "- Example:\n"
                        "word1,translation1\n"
                        "word2,translation2\n"
                        "word3,translation3\n"
                        "\n"
                        "Now output ONLY the CSV rows for the attached image."
                    ),
                    lines=6,
                    placeholder=(
                        "Describe how the CSV should be structured. If left empty, a default OCR-to-CSV prompt is used."
                    ),
                )
                run_button = gr.Button("Run", variant="primary")

            with gr.Column(scale=1):
                csv_output = gr.Textbox(
                    label="CSV output (preview)",
                    lines=20,
                    show_copy_button=True,
                )
                csv_file = gr.File(label="Download CSV file", interactive=False)

        with gr.Row():
            logs_output = gr.Textbox(
                label="Logs",
                lines=4,
            )

        run_button.click(
            fn=process_document,
            inputs=[file_input, model_selector, prompt_editor],
            outputs=[csv_output, csv_file, logs_output],
        )

    return demo


demo = build_interface()

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=int(os.getenv("PORT", "7860")),
        share=True,
    )
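# To run locally (assuming this file is saved as app.py and any backend API keys are set in .env):
#   python app.py
# The app then listens on 0.0.0.0 at the port given by the PORT environment variable (default 7860).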