Skip to content

Instantly share code, notes, and snippets.

@pszemraj
Last active April 7, 2025 22:51
Show Gist options
  • Save pszemraj/c3b7a39c78acf0aba974744920d741d6 to your computer and use it in GitHub Desktop.
Save pszemraj/c3b7a39c78acf0aba974744920d741d6 to your computer and use it in GitHub Desktop.
Standalone Asynchronous RolmOCR Inference Script using vLLM and PyMuPDF.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Standalone Asynchronous RolmOCR Inference Script using vLLM and PyMuPDF.
This script processes PDF files from an input directory using the
reducto/RolmOCR model served locally by vLLM via its OpenAI-compatible API.
It renders each page, sends API requests concurrently for OCR, extracts plain
text, and saves the combined text for each PDF into a corresponding .txt file
in the specified output directory.
This version uses asyncio and the AsyncOpenAI client to significantly speed up
processing by sending multiple page OCR requests to the vLLM server concurrently.
**IMPORTANT:** Requires a separate vLLM server running with the RolmOCR model.
Start the server BEFORE running this script, for example:
vllm serve reducto/RolmOCR --max-num-seqs 256 --gpu-memory-utilization 0.9
Dependencies (vLLM - see vLLM docs for specific CUDA versions):
pip install ninja vllm flash-attn
Dependencies (Script):
pip install "openai>=1.0" PyMuPDF Pillow fire tqdm pypdf mdformat joblib
Example Usage:
# 1. Start the vLLM server in a separate terminal:
# vllm serve reducto/RolmOCR
# 2. Run this script:
python async_pipeline.py \
--input_dir ./my_pdfs \
--output_dir ./output_text \
--model_id reducto/RolmOCR \
--max_pages 100 \
--overwrite \
--api_base_url http://localhost:8000/v1 \
--concurrency_limit 16
"""
import asyncio
import base64
import io
import logging
import os
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import fire
import mdformat
from joblib import Parallel, delayed
from PIL import Image
from pypdf import PdfReader
from pypdf.errors import PdfReadError
from tqdm import tqdm
from tqdm.asyncio import tqdm_asyncio
try:
    from openai import APIConnectionError, APIStatusError, AsyncOpenAI, RateLimitError
except ImportError:
    # SystemExit(str) prints the message to stderr and exits with status 1.
    # Unlike the interactive `exit()` helper, it does not depend on the `site`
    # module having been loaded (e.g. under `python -S` or in frozen builds).
    raise SystemExit(
        "\n".join(
            [
                "=" * 80,
                "ERROR: openai library >= 1.0 not found.",
                "Please install it: pip install 'openai>=1.0'",
                "=" * 80,
            ]
        )
    ) from None
try:
    import fitz  # PyMuPDF
except ImportError:
    raise SystemExit(
        "\n".join(
            [
                "=" * 80,
                "ERROR: PyMuPDF library not found.",
                "Please install it: pip install PyMuPDF",
                "=" * 80,
            ]
        )
    ) from None
# --- Configuration ---
# Script-wide logging: timestamp, level and the emitting function's name.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - [%(funcName)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Raise the HTTP/API client libraries' loggers to WARNING so their INFO/DEBUG
# records do not drown out this script's own progress output.
for _library_logger in ("httpx", "openai", "httpcore"):
    logging.getLogger(_library_logger).setLevel(logging.WARNING)
del _library_logger

DEFAULT_MODEL_ID: str = "reducto/RolmOCR"
# Text instruction sent with every page image; the trailing newline is part of it.
ROLMOCR_PROMPT: str = (
    "Return the plain text representation of this document as if you were "
    "reading it naturally.\n"
)
# Longest side (pixels) that rendered page images are scaled down to.
DEFAULT_TARGET_IMAGE_DIM: int = 1024
# Local OpenAI-compatible vLLM endpoint and its placeholder API key.
DEFAULT_API_BASE_URL: str = "http://localhost:8000/v1"
DEFAULT_API_KEY: str = "EMPTY"
# Maximum number of OCR requests in flight at once.
DEFAULT_CONCURRENCY_LIMIT: int = 16
def render_pdf_page_to_pil_fitz(
    pdf_path: Path,
    page_num: int,
    target_longest_image_dim: int = DEFAULT_TARGET_IMAGE_DIM,
) -> Optional[Image.Image]:
    """
    Render a single page of a PDF to an RGB PIL Image using PyMuPDF (fitz).

    The page is rendered at zoom 1.0 (PyMuPDF's base resolution of one pixel
    per PDF point, i.e. 72 dpi) unless its longest side exceeds
    ``target_longest_image_dim``. In that case it is scaled down so the longest
    side matches that value. Pages are never upscaled, so a US-Letter page
    (612 x 792 pt) comes out at 612 x 792 px even with the default target.

    Args:
        pdf_path: Path to the PDF file.
        page_num: The 1-based page number to render.
        target_longest_image_dim: Upper bound, in pixels, for the longest side
            of the rendered image.

    Returns:
        A PIL Image of the rendered page, or None if rendering fails (missing
        file, invalid page number, zero-size page, or any PyMuPDF error).
    """
    # Resolve PyMuPDF's "file not found" exception defensively. Its location
    # has changed between PyMuPDF releases, and naming a missing attribute
    # (such as `fitz.fitz`) directly in an `except` clause would raise
    # AttributeError while the original exception is being handled, escaping
    # the catch-all handler below.
    missing_file_errors = (
        FileNotFoundError,
        getattr(fitz, "FileNotFoundError", FileNotFoundError),
    )
    doc: Optional[fitz.Document] = None
    try:
        doc = fitz.open(pdf_path)
        if not 0 < page_num <= doc.page_count:
            logger.error(
                f"Invalid page number {page_num} for {pdf_path.name} "
                f"({doc.page_count} pages)."
            )
            return None
        page: fitz.Page = doc.load_page(page_num - 1)  # fitz uses 0-based index
        width, height = page.rect.width, page.rect.height
        longest_side = max(width, height)
        if longest_side <= 0:
            logger.error(
                f"Invalid page dimensions ({width}x{height}) for "
                f"{pdf_path.name} page {page_num}."
            )
            return None
        # Only shrink oversized pages; smaller pages keep zoom 1.0.
        zoom_factor: float = 1.0
        if longest_side > target_longest_image_dim:
            zoom_factor = target_longest_image_dim / longest_side
        matrix = fitz.Matrix(zoom_factor, zoom_factor)
        pix: fitz.Pixmap = page.get_pixmap(matrix=matrix, alpha=False)
        if pix.width == 0 or pix.height == 0:
            logger.error(
                f"Rendered pixmap has zero dimension for {pdf_path.name} "
                f"page {page_num}."
            )
            return None
        # alpha=False with the default RGB colorspace gives 3 bytes per pixel.
        return Image.frombytes("RGB", (pix.width, pix.height), pix.samples)
    except missing_file_errors:
        logger.error(f"PyMuPDF could not find file: {pdf_path}")
        return None
    except Exception as e:
        logger.error(
            f"PyMuPDF error rendering {pdf_path.name} page {page_num}: "
            f"{type(e).__name__} - {e}"
        )
        return None
    finally:
        # Compare against None explicitly: fitz.Document defines __len__ (its
        # page count), so a zero-page document is falsy and `if doc:` would
        # leave it open.
        if doc is not None:
            try:
                doc.close()
            except Exception as e:
                logger.warning(f"Error closing PDF {pdf_path.name}: {e}")
def get_pdf_page_count(pdf_path: Path) -> Optional[int]:
    """
    Return the number of pages in a PDF, using pypdf with a PyMuPDF fallback.

    pypdf is tried first. PyMuPDF (fitz) is consulted when pypdf fails for any
    reason other than a missing file, or when pypdf reports zero pages.

    Args:
        pdf_path: Path to the PDF file.

    Returns:
        The page count. Returns None if the file does not exist or neither
        library can read it. Returns 0 if pypdf reports 0 pages and fitz
        cannot open the file.
    """

    def fitz_page_count() -> int:
        # fitz.Document is a context manager and is closed on exit.
        with fitz.open(pdf_path) as doc:
            return doc.page_count

    def fallback_to_fitz(reason: str) -> Optional[int]:
        # Only a warning: the fitz fallback below may still succeed.
        logger.warning(f"pypdf failed to read {pdf_path.name}: {reason}. Trying fitz.")
        try:
            return fitz_page_count()
        except Exception as fitz_e:
            logger.error(
                f"Both pypdf and fitz failed page count for {pdf_path.name}: {fitz_e}"
            )
            return None

    try:
        count = len(PdfReader(pdf_path, strict=False).pages)
    except FileNotFoundError:
        logger.error(f"File not found for page count: {pdf_path}")
        return None
    except PdfReadError as e:
        return fallback_to_fitz(str(e))
    except Exception as e:
        # Malformed files can also make pypdf raise exception types other
        # than PdfReadError; give fitz a chance for those too.
        return fallback_to_fitz(f"{type(e).__name__} - {e}")

    if count == 0:
        try:
            return fitz_page_count()
        except Exception:
            logger.warning(
                f"pypdf reported 0 pages, fitz failed to open "
                f"{pdf_path.name}. Assuming 0 pages."
            )
            return 0
    return count
def encode_pil_to_base64(image: Image.Image, format: str = "PNG") -> str:
    """
    Serialize a PIL image in memory and return it as base64 text.

    Args:
        image: The image to serialize.
        format: Pillow format name passed to ``Image.save`` (e.g. "PNG", "JPEG").

    Returns:
        The base64 encoding of the saved image bytes, decoded as a UTF-8 string.
    """
    with io.BytesIO() as sink:
        image.save(sink, format=format)
        raw_bytes = sink.getvalue()
    return base64.b64encode(raw_bytes).decode("utf-8")
async def ocr_page_api(
    client: AsyncOpenAI,
    model_id: str,
    img_base64: str,
    page_num: int,
    pdf_name: str,
    semaphore: asyncio.Semaphore,
    temperature: float = 0.1,
    max_tokens: int = 4096,
) -> str:
    """
    Send one page image to the vLLM OpenAI-compatible API and return its text.

    The call runs inside ``semaphore``, which caps the number of concurrent
    requests. API failures never propagate: they are logged and reported as a
    bracketed marker string instead.

    Args:
        client: The initialized AsyncOpenAI client.
        model_id: The model identifier for the API call.
        img_base64: Base64-encoded PNG of the page image.
        page_num: The 1-based page number (for logging).
        pdf_name: The name of the PDF file (for logging).
        semaphore: The asyncio.Semaphore that limits concurrency.
        temperature: Sampling temperature for the model.
        max_tokens: Maximum tokens to generate for the page.

    Returns:
        The stripped text of the first choice, or one of the markers
        "[API_EMPTY_RESPONSE]", "[API_CONNECTION_ERROR]",
        "[API_RATE_LIMIT_ERROR]", "[API_STATUS_ERROR_<code>]" or
        "[API_UNEXPECTED_ERROR]".
    """
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{img_base64}"},
                },
                {"type": "text", "text": ROLMOCR_PROMPT},
            ],
        }
    ]
    async with semaphore:  # Acquire a concurrency slot before calling the API
        try:
            response = await client.chat.completions.create(
                model=model_id,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                # vLLM-specific sampling options such as repetition_penalty are
                # not keyword arguments of the OpenAI client. Pass them via
                # extra_body instead, e.g.
                # extra_body={"repetition_penalty": 1.05},
            )
            if not response.choices:
                logger.warning(
                    f"API returned no choices for page {page_num} ({pdf_name})."
                )
                return "[API_EMPTY_RESPONSE]"
            choice = response.choices[0]
            if choice.finish_reason == "length":
                logger.warning(
                    f"OCR output hit max_tokens={max_tokens} for page {page_num} "
                    f"({pdf_name}); text is likely truncated."
                )
            content = choice.message.content
            return content.strip() if content else "[API_EMPTY_RESPONSE]"
        except APIConnectionError as e:
            logger.error(
                f"API Connect Error page {page_num} ({pdf_name}): {e}. "
                f"Is server at {client.base_url} running?"
            )
            return "[API_CONNECTION_ERROR]"
        # RateLimitError subclasses APIStatusError, so it must be caught first.
        except RateLimitError as e:
            logger.warning(
                f"API Rate Limit Error page {page_num} ({pdf_name}): {e}. "
                f"Server busy or concurrency too high? Retrying may be needed."
            )
            # NOTE: this pause happens while the semaphore is still held, so
            # it also delays other queued pages. No retry is attempted here.
            await asyncio.sleep(2)
            return "[API_RATE_LIMIT_ERROR]"
        except APIStatusError as e:
            # e.message carries the server's error text; e.response would only
            # log the httpx Response repr (e.g. "<Response [500 ...]>").
            logger.error(
                f"API Status Error page {page_num} ({pdf_name}): "
                f"Status={e.status_code}, Message={e.message}"
            )
            return f"[API_STATUS_ERROR_{e.status_code}]"
        except Exception as e:
            # Catch anything unexpected so one page cannot fail the whole gather.
            logger.exception(f"Unexpected API Error page {page_num} ({pdf_name}): {e}")
            return "[API_UNEXPECTED_ERROR]"
def render_and_encode_single_page(
    pdf_file: Path, page_num: int, target_image_dim: int, pdf_name: str
) -> Tuple[int, str]:
    """
    Render one PDF page and base64-encode it, for use as a joblib task.

    This function never raises. Failures are returned as error markers, so a
    single bad page cannot make joblib's ``Parallel`` re-raise in the parent
    and abort the whole batch.

    Args:
        pdf_file: Path to the PDF file.
        page_num: Page number to render (1-based).
        target_image_dim: Target size for the longest dimension.
        pdf_name: Name of the PDF file (for logging).

    Returns:
        ``(page_num, data)``, where ``data`` is the page's base64 PNG string or
        one of "[PAGE_RENDER_ERROR]" / "[IMAGE_ENCODE_ERROR]".
    """
    try:
        pil_image = render_pdf_page_to_pil_fitz(pdf_file, page_num, target_image_dim)
    except Exception as e:
        # The renderer is meant to return None on failure; guard against
        # anything that still escapes it.
        logger.error(f"Unexpected error rendering page {page_num} ({pdf_name}): {e}")
        return page_num, "[PAGE_RENDER_ERROR]"
    if pil_image is None:
        logger.warning(f"Failed to render page {page_num} ({pdf_name})")
        return page_num, "[PAGE_RENDER_ERROR]"
    try:
        return page_num, encode_pil_to_base64(pil_image)
    except Exception as e:
        logger.error(f"Failed to encode page {page_num} ({pdf_name}): {e}")
        return page_num, "[IMAGE_ENCODE_ERROR]"
# --- Main Processing Logic ---

# Matches a page result that consists only of an error/placeholder marker,
# e.g. "[API_CONNECTION_ERROR]" or "[PAGE_3_MISSING_UNEXPECTEDLY]": square
# brackets around uppercase letters, digits and underscores.
_ERROR_MARKER_PATTERN = re.compile(r"^\s*\[[A-Z0-9_]+\]\s*$")

# mdformat options applied to each page's OCR text.
_MDFORMAT_OPTIONS: Dict[str, Any] = {"number": True, "wrap": "no"}


def _is_error_marker(text: str) -> bool:
    """Return True if ``text`` is only an error/placeholder marker."""
    return _ERROR_MARKER_PATTERN.match(text.strip()) is not None


def _render_pages_parallel(
    pdf_file: Path, num_pages: int, target_image_dim: int
) -> Dict[int, str]:
    """
    Render and base64-encode pages 1..num_pages of ``pdf_file`` with joblib.

    Returns:
        Mapping of 1-based page number to the page's base64 PNG string or an
        error marker such as "[PAGE_RENDER_ERROR]".
    """
    logger.debug(
        f"Rendering/encoding {num_pages} pages for {pdf_file.name} in parallel"
    )
    # Use at most 8 cores or CPU count, whichever is smaller
    n_jobs = min(8, os.cpu_count() or 1)
    logger.info(f"Using {n_jobs} cores for parallel page rendering")
    results = Parallel(n_jobs=n_jobs, verbose=1)(
        delayed(render_and_encode_single_page)(
            pdf_file, page_num, target_image_dim, pdf_file.name
        )
        for page_num in range(1, num_pages + 1)
    )
    return dict(results)


async def _ocr_rendered_pages(
    client: AsyncOpenAI,
    page_data: Dict[int, str],
    pdf_file: Path,
    semaphore: asyncio.Semaphore,
    model_id: str,
    temperature: float,
    max_tokens: int,
) -> Dict[int, str]:
    """
    OCR every successfully rendered page concurrently through the API.

    Pages whose entry in ``page_data`` is an error marker (starts with "[")
    are not sent; their marker is carried through to the result unchanged.

    Returns:
        Mapping of page number to OCR text or error marker.
    """
    # Base64 text never starts with "[", so this separates render/encode
    # failures from real page images.
    page_texts: Dict[int, str] = {
        pn: data for pn, data in page_data.items() if data.startswith("[")
    }
    tasks: Dict[int, "asyncio.Task[str]"] = {}
    for page_num in sorted(page_data):
        img_data = page_data[page_num]
        if not img_data or img_data.startswith("["):
            continue
        tasks[page_num] = asyncio.create_task(
            ocr_page_api(
                client=client,
                model_id=model_id,
                img_base64=img_data,
                page_num=page_num,
                pdf_name=pdf_file.name,
                semaphore=semaphore,
                temperature=temperature,
                max_tokens=max_tokens,
            ),
            name=f"OCR_{pdf_file.stem}_p{page_num}",
        )

    if not tasks:
        logger.warning(
            f"No pages successfully rendered/encoded for {pdf_file.name}. "
            "Skipping API calls."
        )
        return page_texts

    logger.info(f"Submitting {len(tasks)} pages to API for {pdf_file.name}")
    results: List[str] = await tqdm_asyncio.gather(
        *tasks.values(),
        desc=f" OCR Pages ({pdf_file.name[:20]})",
        unit="page",
        leave=False,
    )
    # gather returns results in input order, and dicts preserve insertion
    # order, so results line up with the task keys.
    page_texts.update(zip(tasks, results))
    return page_texts


def _format_page_text(text: str, page_num: int, pdf_name: str) -> str:
    """
    Format one page's text with mdformat, without a trailing newline.

    Error markers are returned stripped but otherwise unformatted. If mdformat
    raises, the raw text is kept so a single odd page cannot lose the document.
    """
    if _is_error_marker(text):
        return text.strip()
    try:
        return mdformat.text(text, options=_MDFORMAT_OPTIONS).rstrip("\n")
    except Exception as e:
        logger.warning(
            f"mdformat failed on page {page_num} ({pdf_name}): {e}. Keeping raw text."
        )
        return text.rstrip("\n")


def _assemble_document_text(
    page_texts: Dict[int, str], num_pages: int, pdf_name: str
) -> Optional[str]:
    """
    Combine per-page results into the final document text.

    Every page 1..num_pages gets a slot (missing ones get a placeholder), and
    error markers stay in place, so the N-th form-feed-separated chunk of the
    output corresponds to page N.

    Returns:
        The combined text ending in a newline, or None if no page produced
        usable (non-empty, non-marker) text.
    """
    ordered_texts: List[str] = [
        page_texts.get(pn, f"[PAGE_{pn}_MISSING_UNEXPECTEDLY]")
        for pn in range(1, num_pages + 1)
    ]
    if not any(text.strip() and not _is_error_marker(text) for text in ordered_texts):
        return None
    # Format each page on its own and join afterwards. A line holding only
    # "\f" is not a blank line in CommonMark, so formatting the joined text
    # lets mdformat merge the separator into the surrounding paragraphs.
    formatted_pages = [
        _format_page_text(text, pn, pdf_name)
        for pn, text in enumerate(ordered_texts, start=1)
    ]
    # Use form feed character (\f) as page separator
    return "\n\f\n".join(formatted_pages) + "\n"


async def _process_single_pdf(
    pdf_file: Path,
    output_txt_path: Path,
    client: AsyncOpenAI,
    semaphore: asyncio.Semaphore,
    model_id: str,
    target_image_dim: int,
    max_pages: Optional[int],
    temperature: float,
    max_tokens_per_page: int,
) -> None:
    """Count, render, OCR and write the .txt output for one PDF file."""
    logger.info(f"Starting processing for {pdf_file.name}")
    page_count = get_pdf_page_count(pdf_file)
    if page_count is None:
        logger.warning(f"Skipping {pdf_file.name}, failed to get page count.")
        output_txt_path.write_text("[ERROR_READING_PDF]", encoding="utf-8")
        return
    if page_count == 0:
        logger.warning(f"Skipping {pdf_file.name}, contains 0 pages.")
        output_txt_path.write_text("", encoding="utf-8")  # Empty file
        return

    num_pages = page_count
    if max_pages is not None and 0 < max_pages < page_count:
        logger.info(f"Limiting to first {max_pages} pages of {pdf_file.name}")
        num_pages = max_pages

    # Parallel(...) is a blocking call, so the event loop waits here until all
    # pages are rendered; OCR requests start afterwards.
    page_data = _render_pages_parallel(pdf_file, num_pages, target_image_dim)
    page_texts = await _ocr_rendered_pages(
        client,
        page_data,
        pdf_file,
        semaphore,
        model_id,
        temperature,
        max_tokens_per_page,
    )
    if not page_texts:
        logger.warning(f"No text results generated for {pdf_file.name}.")
        output_txt_path.write_text("", encoding="utf-8")
        return

    final_text = _assemble_document_text(page_texts, num_pages, pdf_file.name)
    if final_text is None:
        logger.warning(f"All pages were filtered out for {pdf_file.name}.")
        output_txt_path.write_text("", encoding="utf-8")  # Empty file
        return

    try:
        output_txt_path.write_text(final_text, encoding="utf-8")
        logger.info(f"Successfully wrote output: {output_txt_path.name}")
    except (OSError, UnicodeError) as e:
        logger.error(f"Failed to write output file {output_txt_path}: {e}")


async def process_directory(
    input_dir: str,
    output_dir: Optional[str] = None,
    model_id: str = DEFAULT_MODEL_ID,
    api_base_url: str = DEFAULT_API_BASE_URL,
    api_key: str = DEFAULT_API_KEY,
    target_image_dim: int = DEFAULT_TARGET_IMAGE_DIM,
    max_pages: Optional[int] = None,
    temperature: float = 0.1,
    max_tokens_per_page: int = 4096,
    overwrite: bool = False,
    concurrency_limit: int = DEFAULT_CONCURRENCY_LIMIT,
) -> None:
    """
    Process PDF files asynchronously using RolmOCR via vLLM's OpenAI API.

    For each ``*.pdf`` in ``input_dir`` (in sorted order), pages are rendered
    and encoded, OCR'd through concurrent API requests, formatted with mdformat
    and written to ``<stem>.txt`` in the output directory, with pages separated
    by form feeds. A failure in one PDF is logged and the remaining PDFs are
    still processed.

    Args:
        input_dir: Path to the directory containing input PDF files.
        output_dir: Path to the directory for output .txt files. If None,
            creates a directory next to input_dir.
        model_id: Model ID for the vLLM server API.
        api_base_url: Base URL of the vLLM OpenAI-compatible API endpoint.
        api_key: API key for the endpoint (usually 'EMPTY' for local vLLM).
        target_image_dim: Target size for the longest dimension of page images.
        max_pages: Max pages to process per PDF (None for all pages).
        temperature: Sampling temperature for the model (0.0-0.2 recommended).
        max_tokens_per_page: Max tokens the model can generate per page.
        overwrite: If True, overwrite existing output .txt files.
        concurrency_limit: Maximum number of concurrent API requests.
    """
    input_path = Path(input_dir).resolve()
    assert input_path.is_dir(), (
        f"Input directory not found or is not a directory: {input_path}"
    )
    output_path = (
        Path(output_dir).resolve()
        if output_dir is not None
        else input_path.parent / f"output-pdftotext-{input_path.name}"
    )
    output_path.mkdir(parents=True, exist_ok=True)
    logger.info(f"Input directory: {input_path}")
    logger.info(f"Output directory: {output_path}")
    logger.info(f"Model API: {model_id} at {api_base_url}")
    logger.info(f"Concurrency: {concurrency_limit}")
    logger.info(f"Target Image Dim: {target_image_dim}")
    logger.info(f"Overwrite: {overwrite}")

    client: Optional[AsyncOpenAI] = None
    try:
        client = AsyncOpenAI(api_key=api_key, base_url=api_base_url)
        logger.info(f"AsyncOpenAI client initialized for {api_base_url}")
        pdf_files: List[Path] = sorted(input_path.glob("*.pdf"))
        if not pdf_files:
            logger.warning(f"No PDF files found in {input_path}")
            return
        logger.info(f"Found {len(pdf_files)} PDF files.")
        # One semaphore shared by all PDFs caps in-flight API requests.
        semaphore = asyncio.Semaphore(concurrency_limit)
        for pdf_file in tqdm(pdf_files, desc="Processing PDFs", unit="pdf"):
            output_txt_path = output_path / (pdf_file.stem + ".txt")
            if not overwrite and output_txt_path.exists():
                logger.info(f"Skipping {pdf_file.name}, output exists.")
                continue
            try:
                await _process_single_pdf(
                    pdf_file=pdf_file,
                    output_txt_path=output_txt_path,
                    client=client,
                    semaphore=semaphore,
                    model_id=model_id,
                    target_image_dim=target_image_dim,
                    max_pages=max_pages,
                    temperature=temperature,
                    max_tokens_per_page=max_tokens_per_page,
                )
            except Exception as e:
                # Isolate failures so one problematic PDF does not abort the
                # rest of the directory.
                logger.exception(f"Failed to process {pdf_file.name}: {e}")
    except Exception as e:
        logger.exception(f"An unexpected error occurred during processing: {e}")
    finally:
        if client is not None:
            await client.close()
            logger.info("AsyncOpenAI client closed.")
        logger.info("Processing run finished.")
def main(**kwargs: Any) -> None:
    """
    Command-line entry point: run `process_directory` on a fresh event loop.

    fire maps command-line flags to keyword arguments, so any parameter of
    `process_directory` can be supplied on the command line, e.g.
    `--input_dir ./pdfs --max_pages 5`.

    Args:
        **kwargs: Keyword arguments forwarded unchanged to `process_directory`.
    """
    # asyncio.run() cannot be called while another event loop is already
    # running in this thread (as in Jupyter). There, applying the third-party
    # `nest_asyncio` package before calling this function may be required.
    try:
        pending_run = process_directory(**kwargs)
        asyncio.run(pending_run)
    except KeyboardInterrupt:
        logger.info("Processing interrupted by user.")


if __name__ == "__main__":
    fire.Fire(main)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment