#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Standalone Asynchronous RolmOCR Inference Script using vLLM and PyMuPDF.

This script processes PDF files from an input directory using the
reducto/RolmOCR model served locally by vLLM via its OpenAI-compatible API.
It renders each page, sends API requests concurrently for OCR, extracts plain
text, and saves the combined text for each PDF into a corresponding .txt file
in the specified output directory.

This version uses asyncio and the AsyncOpenAI client to significantly speed up
processing by sending multiple page OCR requests to the vLLM server concurrently.

**IMPORTANT:** Requires a separate vLLM server running with the RolmOCR model.
Start the server BEFORE running this script, for example:

    vllm serve reducto/RolmOCR --max-num-seqs 256 --gpu-memory-utilization 0.9

Dependencies (vLLM - see vLLM docs for specific CUDA versions):

    pip install ninja vllm flash-attn

Dependencies (Script):

    pip install "openai>=1.0" PyMuPDF Pillow fire tqdm pypdf mdformat "tqdm[asyncio]" joblib

Example Usage:

    # 1. Start the vLLM server in a separate terminal:
    #    vllm serve reducto/RolmOCR
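    #    (Optional sanity check: the server is ready once its OpenAI-compatible
    #    API responds, e.g. curl http://localhost:8000/v1/models on the
    #    default port.)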
    # 2. Run this script:
    python async_pipeline.py \
        --input_dir ./my_pdfs \
        --output_dir ./output_text \
        --model_id reducto/RolmOCR \
        --max_pages 100 \
        --overwrite \
        --api_base_url http://localhost:8000/v1 \
        --concurrency_limit 16
"""

import asyncio
import base64
import io
import logging
import os
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import fire
import mdformat
from joblib import Parallel, delayed
from PIL import Image
from pypdf import PdfReader
from pypdf.errors import PdfReadError
from tqdm import tqdm
from tqdm.asyncio import tqdm_asyncio

try:
    from openai import APIConnectionError, APIStatusError, AsyncOpenAI, RateLimitError
except ImportError:
    print("=" * 80)
    print("ERROR: openai library >= 1.0 not found.")
    print("Please install it: pip install 'openai>=1.0'")
    print("=" * 80)
    exit(1)

try:
    import fitz  # PyMuPDF
except ImportError:
    print("=" * 80)
    print("ERROR: PyMuPDF library not found.")
    print("Please install it: pip install PyMuPDF")
    print("=" * 80)
    exit(1)

# --- Configuration ---
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - [%(funcName)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)

# Reduce noise from underlying libraries
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("openai").setLevel(logging.WARNING)
logging.getLogger("httpcore").setLevel(logging.WARNING)

DEFAULT_MODEL_ID: str = "reducto/RolmOCR"
ROLMOCR_PROMPT: str = "Return the plain text representation of this document as if you were reading it naturally.\n"
DEFAULT_TARGET_IMAGE_DIM: int = 1024
DEFAULT_API_BASE_URL: str = "http://localhost:8000/v1"
DEFAULT_API_KEY: str = "EMPTY"
DEFAULT_CONCURRENCY_LIMIT: int = 16
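# Note: concurrency_limit caps in-flight requests on the client side only; for
# best throughput it should generally not exceed the server's --max-num-seqs.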


def render_pdf_page_to_pil_fitz(
    pdf_path: Path,
    page_num: int,
    target_longest_image_dim: int = DEFAULT_TARGET_IMAGE_DIM,
) -> Optional[Image.Image]:
    """
    Renders a single page of a PDF to a PIL Image using PyMuPDF (fitz).

    Resizes the image so its longest dimension matches target_longest_image_dim,
    but only downscales (does not upscale).

    Args:
        pdf_path: Path to the PDF file.
        page_num: The 1-based page number to render.
        target_longest_image_dim: Target size for the longest dimension.

    Returns:
        A PIL Image object of the rendered page, or None if rendering fails.
    """
    doc: Optional[fitz.Document] = None
    try:
        doc = fitz.open(pdf_path)
        if not 0 < page_num <= doc.page_count:
            logger.error(
                f"Invalid page number {page_num} for {pdf_path.name} "
                f"({doc.page_count} pages)."
            )
            return None
        page: fitz.Page = doc.load_page(page_num - 1)  # fitz uses 0-based index
        page_rect: fitz.Rect = page.rect
        width, height = page_rect.width, page_rect.height
        if max(width, height) <= 0:
            logger.error(
                f"Invalid page dimensions ({width}x{height}) for "
                f"{pdf_path.name} page {page_num}."
            )
            return None
        zoom_factor: float = 1.0
        if max(width, height) > target_longest_image_dim:
            zoom_factor = target_longest_image_dim / max(width, height)
        matrix: fitz.Matrix = fitz.Matrix(zoom_factor, zoom_factor)
        pix: fitz.Pixmap = page.get_pixmap(matrix=matrix, alpha=False)
        if pix.width == 0 or pix.height == 0:
            logger.error(
                f"Rendered pixmap has zero dimension for {pdf_path.name} "
                f"page {page_num}."
            )
            return None
        img: Image.Image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
        return img
    except (fitz.FileNotFoundError, FileNotFoundError):
        logger.error(f"PyMuPDF could not find file: {pdf_path}")
        return None
    except Exception as e:
        logger.error(
            f"PyMuPDF error rendering {pdf_path.name} page {page_num}: "
            f"{type(e).__name__} - {e}"
        )
        return None
    finally:
        if doc:
            try:
                doc.close()
            except Exception as e:
                logger.warning(f"Error closing PDF {pdf_path.name}: {e}")


def get_pdf_page_count(pdf_path: Path) -> Optional[int]:
    """
    Gets the number of pages in a PDF file using pypdf, with fitz fallback.

    Args:
        pdf_path: Path to the PDF file.

    Returns:
        The number of pages as an integer, or None if reading fails.
    """
    try:
        reader = PdfReader(pdf_path, strict=False)
        count = len(reader.pages)
        if count == 0:
            try:
                with fitz.open(pdf_path) as doc:
                    count = doc.page_count
            except Exception:
                logger.warning(
                    f"pypdf reported 0 pages, fitz failed to open "
                    f"{pdf_path.name}. Assuming 0 pages."
                )
                return 0
        return count
    except PdfReadError as e:
        logger.error(f"pypdf failed to read {pdf_path.name}: {e}. Trying fitz.")
        try:
            with fitz.open(pdf_path) as doc:
                return doc.page_count
        except Exception as fitz_e:
            logger.error(
                f"Both pypdf and fitz failed page count for {pdf_path.name}: {fitz_e}"
            )
            return None
    except FileNotFoundError:
        logger.error(f"File not found for page count: {pdf_path}")
        return None
    except Exception as e:
        logger.error(f"Unexpected error getting page count for {pdf_path.name}: {e}")
        return None


def encode_pil_to_base64(image: Image.Image, format: str = "PNG") -> str:
    """
    Encodes a PIL image object to a base64 string.

    Args:
        image: The PIL Image object.
        format: The image format to use (e.g., "PNG", "JPEG").
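            Note: the data URL built in ocr_page_api assumes image/png, so if
            you pass "JPEG" (often a much smaller payload for photographic
            pages), update the MIME type there as well.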

    Returns:
        The base64 encoded string representation of the image.
    """
    buffered = io.BytesIO()
    image.save(buffered, format=format)
    img_byte = buffered.getvalue()
    img_base64 = base64.b64encode(img_byte)
    return img_base64.decode("utf-8")


async def ocr_page_api(
    client: AsyncOpenAI,
    model_id: str,
    img_base64: str,
    page_num: int,
    pdf_name: str,
    semaphore: asyncio.Semaphore,
    temperature: float = 0.1,
    max_tokens: int = 4096,
) -> str:
    """
    Sends a single page image to the vLLM OpenAI API for OCR asynchronously.

    Uses an asyncio.Semaphore to limit the number of concurrent requests.

    Args:
        client: The initialized AsyncOpenAI client.
        model_id: The model identifier for the API call.
        img_base64: The base64 encoded string of the page image.
        page_num: The 1-based page number (for logging).
        pdf_name: The name of the PDF file (for logging).
        semaphore: The asyncio.Semaphore to control concurrency.
        temperature: Sampling temperature for the model.
        max_tokens: Maximum tokens to generate for the page.

    Returns:
        The extracted text content as a string, or an error marker string
        (e.g., "[API_CONNECTION_ERROR]") if an API error occurs.
    """
    async with semaphore:  # Acquire semaphore before making the API call
        try:
            response = await client.chat.completions.create(
                model=model_id,
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:image/png;base64,{img_base64}"
                                },
                            },
                            {"type": "text", "text": ROLMOCR_PROMPT},
                        ],
                    }
                ],
                temperature=temperature,
                max_tokens=max_tokens,
                # vLLM-specific sampling params (e.g., a repetition penalty)
                # can be passed via extra_body if needed:
                # extra_body={"repetition_penalty": 1.05},
            )
            content = response.choices[0].message.content
            return content.strip() if content else "[API_EMPTY_RESPONSE]"
        except APIConnectionError as e:
            logger.error(
                f"API Connect Error page {page_num} ({pdf_name}): {e}. "
                f"Is server at {client.base_url} running?"
            )
            return "[API_CONNECTION_ERROR]"
        except RateLimitError as e:
            logger.warning(
                f"API Rate Limit Error page {page_num} ({pdf_name}): {e}. "
                f"Server busy or concurrency too high? Retrying may be needed."
            )
            # Simple backoff, consider more robust retry logic if needed
            await asyncio.sleep(2)
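            # A more robust approach (sketch only, not wired in here) would
            # retry the request itself with exponential backoff, e.g.:
            #   for attempt in range(max_retries):
            #       try:
            #           return await client.chat.completions.create(...)
            #       except RateLimitError:
            #           await asyncio.sleep(2 ** attempt)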
return "[API_RATE_LIMIT_ERROR]" | |
except APIStatusError as e: | |
logger.error( | |
f"API Status Error page {page_num} ({pdf_name}): " | |
f"Status={e.status_code}, Response={e.response}" | |
) | |
return f"[API_STATUS_ERROR_{e.status_code}]" | |
except Exception as e: | |
# Catch unexpected errors during the API call itself | |
logger.exception(f"Unexpected API Error page {page_num} ({pdf_name}): {e}") | |
return "[API_UNEXPECTED_ERROR]" | |


def render_and_encode_single_page(
    pdf_file: Path, page_num: int, target_image_dim: int, pdf_name: str
) -> Tuple[int, str]:
    """
    Renders and encodes a single PDF page in one function for parallel processing.

    Args:
        pdf_file: Path to the PDF file
        page_num: Page number to render (1-based)
        target_image_dim: Target size for longest dimension
        pdf_name: Name of PDF file (for logging)

    Returns:
        Tuple of (page_num, base64_string or error_marker)
    """
    pil_image = render_pdf_page_to_pil_fitz(pdf_file, page_num, target_image_dim)
    if not pil_image:
        logger.warning(f"Failed to render page {page_num} ({pdf_name})")
        return page_num, "[PAGE_RENDER_ERROR]"
    try:
        img_base64 = encode_pil_to_base64(pil_image)
        return page_num, img_base64
    except Exception as e:
        logger.error(f"Failed to encode page {page_num} ({pdf_name}): {e}")
        return page_num, "[IMAGE_ENCODE_ERROR]"


# --- Main Processing Logic ---
async def process_directory(
    input_dir: str,
    output_dir: Optional[str] = None,
    model_id: str = DEFAULT_MODEL_ID,
    api_base_url: str = DEFAULT_API_BASE_URL,
    api_key: str = DEFAULT_API_KEY,
    target_image_dim: int = DEFAULT_TARGET_IMAGE_DIM,
    max_pages: Optional[int] = None,
    temperature: float = 0.1,
    max_tokens_per_page: int = 4096,
    overwrite: bool = False,
    concurrency_limit: int = DEFAULT_CONCURRENCY_LIMIT,
) -> None:
    """
    Processes PDF files asynchronously using RolmOCR via vLLM's OpenAI API.

    Renders pages, encodes them, sends concurrent API requests for OCR,
    combines results, and saves text files.

    Args:
        input_dir: Path to the directory containing input PDF files.
        output_dir: Path to the directory for output .txt files. If None,
            creates a directory next to input_dir.
        model_id: Model ID for the vLLM server API.
        api_base_url: Base URL of the vLLM OpenAI-compatible API endpoint.
        api_key: API key for the endpoint (usually 'EMPTY' for local vLLM).
        target_image_dim: Target size for the longest dimension of page images.
        max_pages: Max pages to process per PDF (None for all pages).
        temperature: Sampling temperature for the model (0.0-0.2 recommended).
        max_tokens_per_page: Max tokens the model can generate per page.
        overwrite: If True, overwrite existing output .txt files.
        concurrency_limit: Maximum number of concurrent API requests.
    """
    input_path = Path(input_dir).resolve()
    assert input_path.is_dir(), (
        f"Input directory not found or is not a directory: {input_path}"
    )
    output_path = (
        Path(output_dir).resolve()
        if output_dir is not None
        else input_path.parent / f"output-pdftotext-{input_path.name}"
    )
    output_path.mkdir(parents=True, exist_ok=True)
    logger.info(f"Input directory: {input_path}")
    logger.info(f"Output directory: {output_path}")
    logger.info(f"Model API: {model_id} at {api_base_url}")
    logger.info(f"Concurrency: {concurrency_limit}")
    logger.info(f"Target Image Dim: {target_image_dim}")
    logger.info(f"Overwrite: {overwrite}")

    client: Optional[AsyncOpenAI] = None
    try:
        client = AsyncOpenAI(api_key=api_key, base_url=api_base_url)
        logger.info(f"AsyncOpenAI client initialized for {api_base_url}")

        pdf_files: List[Path] = sorted(input_path.glob("*.pdf"))
        if not pdf_files:
            logger.warning(f"No PDF files found in {input_path}")
            return
        logger.info(f"Found {len(pdf_files)} PDF files.")

        semaphore = asyncio.Semaphore(concurrency_limit)
        for pdf_file in tqdm(pdf_files, desc="Processing PDFs", unit="pdf"):
            output_txt_path = output_path / (pdf_file.stem + ".txt")
            if not overwrite and output_txt_path.exists():
                logger.info(f"Skipping {pdf_file.name}, output exists.")
                continue
            logger.info(f"Starting processing for {pdf_file.name}")

            page_count = get_pdf_page_count(pdf_file)
            if page_count is None:
                logger.warning(f"Skipping {pdf_file.name}, failed to get page count.")
                output_txt_path.write_text("[ERROR_READING_PDF]", encoding="utf-8")
                continue
            if page_count == 0:
                logger.warning(f"Skipping {pdf_file.name}, contains 0 pages.")
                output_txt_path.write_text("", encoding="utf-8")  # Empty file
                continue

            num_pages_to_process = page_count
            if max_pages is not None and 0 < max_pages < page_count:
                logger.info(f"Limiting to first {max_pages} pages of {pdf_file.name}")
                num_pages_to_process = max_pages
            # --- Preprocessing: Render and Encode Pages ---
            # Rendering/encoding is CPU-bound, so it is parallelized across
            # processes with joblib before the async API calls are issued.
            page_render_encode_data: Dict[int, str] = {}  # page_num -> base64 or error
            logger.debug(
                f"Rendering/encoding {num_pages_to_process} pages for {pdf_file.name} in parallel"
            )
            # Use at most 8 cores or the CPU count, whichever is smaller
            n_jobs = min(8, os.cpu_count() or 1)
            logger.info(f"Using {n_jobs} cores for parallel page rendering")

            # Render and encode pages in parallel
            parallel_results = Parallel(n_jobs=n_jobs, verbose=1)(
                delayed(render_and_encode_single_page)(
                    pdf_file, page_num, target_image_dim, pdf_file.name
                )
                for page_num in range(1, num_pages_to_process + 1)
            )

            # Process the results
            valid_pages_for_api = 0
            for page_num, result in parallel_results:
                page_render_encode_data[page_num] = result
                if not result.startswith("["):  # Not an error marker
                    valid_pages_for_api += 1

            if valid_pages_for_api == 0:
                logger.warning(
                    f"No pages successfully rendered/encoded for {pdf_file.name}. "
                    "Skipping API calls."
                )
                # Store only the errors encountered during preprocessing
                all_page_texts = {
                    pn: data for pn, data in page_render_encode_data.items()
                }
            else:
                # --- Asynchronous API Calls ---
                tasks: List[Tuple[int, asyncio.Task[str]]] = []
                logger.info(
                    f"Submitting {valid_pages_for_api} pages to API for {pdf_file.name}"
                )
                for page_num in range(1, num_pages_to_process + 1):
                    img_data = page_render_encode_data.get(page_num)
                    # Only create tasks for pages that were successfully processed
                    if img_data and not img_data.startswith("["):
                        task = asyncio.create_task(
                            ocr_page_api(
                                client=client,
                                model_id=model_id,
                                img_base64=img_data,
                                page_num=page_num,
                                pdf_name=pdf_file.name,
                                semaphore=semaphore,
                                temperature=temperature,
                                max_tokens=max_tokens_per_page,
                            ),
                            name=f"OCR_{pdf_file.stem}_p{page_num}",  # Optional: name tasks
                        )
                        tasks.append((page_num, task))

                # Run tasks concurrently and display progress with tqdm.asyncio
                api_results: List[str] = await tqdm_asyncio.gather(
                    *(task for _, task in tasks),
                    desc=f" OCR Pages ({pdf_file.name[:20]})",
                    unit="page",
                    leave=False,
                )

                # --- Combine Results ---
                all_page_texts: Dict[int, str] = {}
                # First, add back any errors from the preprocessing stage
                for pn, data in page_render_encode_data.items():
                    if data.startswith("["):
                        all_page_texts[pn] = data
                # Then, add the results from the successful API calls
                for i, (page_num, _) in enumerate(tasks):
                    all_page_texts[page_num] = api_results[i]

            if not all_page_texts:
                logger.warning(f"No text results generated for {pdf_file.name}.")
                output_txt_path.write_text("", encoding="utf-8")
                continue
            # Create a regex pattern that matches error markers:
            # - Enclosed in brackets [...]
            # - Contains uppercase letters, numbers, and underscores inside
            ERROR_PATTERN = re.compile(r"^\s*\[[A-Z0-9_]+\]\s*$")

            # Ensure pages are ordered correctly, substituting placeholders if needed
            ordered_texts: List[str] = [
                all_page_texts.get(pn, f"[PAGE_{pn}_MISSING_UNEXPECTEDLY]")
                for pn in range(1, num_pages_to_process + 1)
            ]
            # Filter out error messages and placeholders
            filtered_texts: List[str] = [
                text
                for text in ordered_texts
                if text.strip() and not ERROR_PATTERN.match(text.strip())
            ]
            # If we filtered out everything, log a warning
            if not filtered_texts:
                logger.warning(f"All pages were filtered out for {pdf_file.name}.")
                output_txt_path.write_text("", encoding="utf-8")  # Empty file
                continue
            # Use form feed character (\f) as page separator; join the filtered
            # texts so error markers and empty pages do not leak into the output
            final_text: str = "\n\f\n".join(filtered_texts)
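            # Downstream consumers can recover per-page text by splitting on
            # the form feed, e.g. text.split("\f"), assuming the mdformat pass
            # below preserves it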
            final_text = mdformat.text(
                final_text,
                options={
                    "number": True,
                    "wrap": "no",
                },
            )

            try:
                output_txt_path.write_text(final_text, encoding="utf-8")
                logger.info(f"Successfully wrote output: {output_txt_path.name}")
            except Exception as e:
                logger.error(f"Failed to write output file {output_txt_path}: {e}")
    except Exception as e:
        logger.exception(f"An unexpected error occurred during processing: {e}")
    finally:
        if client:
            await client.close()
            logger.info("AsyncOpenAI client closed.")
        logger.info("Processing run finished.")


def main(**kwargs: Any) -> None:
    """
    Command-line entry point wrapper to run the async processing function.

    Uses the fire library to handle command-line arguments. Any argument
    accepted by `process_directory` can be passed via the command line, e.g.,
    `--input_dir ./pdfs --max_pages 5`.

    Args:
        **kwargs: Arguments passed from the command line via fire.
    """
    # If running in an environment like Jupyter that already has an event loop,
    # you might need nest_asyncio:
    # try:
    #     import nest_asyncio
    #     nest_asyncio.apply()
    # except ImportError:
    #     pass  # Not needed in standard script execution
    try:
        asyncio.run(process_directory(**kwargs))
    except KeyboardInterrupt:
        logger.info("Processing interrupted by user.")


if __name__ == "__main__":
    fire.Fire(main)