Created July 13, 2025 18:46
Invoke like this:

➜ python scripts/run_test.py RUST_LOG=debug cargo test -p ohh-parser --test sample_hands parse_all_samples -- --nocapture
➜ cat failure_context.txt | pbcopy
<paste into grok>

run_test.py:
```
import subprocess
import re
import pyperclip
from pathlib import Path
import sys
import argparse
import os
from openai import OpenAI
from pydantic import BaseModel, Field
import json
import logging
from typing import Any, Dict
import tiktoken

# Config Defaults
OUTPUT_FILE = "failure_context.txt"
OPENAI_MODEL = "gpt-4o-mini"  # Updated to use GPT-4o-mini
USE_LLM = True  # Default to using LLM for filtering
MAX_FILES = None  # No limit by default

logger = logging.getLogger(__name__)


class LLMResponse:
    """Standardized response object from LLM completions."""
    content: str
    raw_response: Any = None
    model: str = ""
    usage: Dict[str, int] = None

    def __init__(self, content, raw_response=None, model="", usage=None):
        self.content = content
        self.raw_response = raw_response
        self.model = model
        self.usage = usage or {}

class LLMClient:
    """A simple client for LLM interactions with schema validation."""

    def __init__(self, default_model: str):
        self.default_model = default_model
        self.client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

    def complete(
        self,
        messages: list[dict[str, str]],
        model: str = None,
        response_schema: BaseModel = None,
    ):
        model = model or self.default_model
        kwargs = {}
        if response_schema:
            kwargs["response_format"] = {"type": "json_object"}
        try:
            response = self.client.chat.completions.create(
                model=model,
                messages=messages,
                **kwargs
            )
            content = response.choices[0].message.content
            if response_schema and content:
                try:
                    # Strip markdown formatting if present
                    if content.strip().startswith("```"):
                        content = re.sub(r'^```(?:json)?\s*\n', '', content.strip())
                        content = re.sub(r'\n```\s*$', '', content)
                    parsed_content = response_schema.model_validate_json(content)
                    content = parsed_content
                except Exception as e:
                    logger.error(f"Failed to parse response with schema: {e}")
                    # Keep raw content if parsing fails
            return LLMResponse(
                content=content,
                raw_response=response,
                model=response.model,
                usage=response.usage.dict() if response.usage else {},
            )
        except Exception as e:
            logger.error(f"LLM completion failed: {e}")
            raise
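
# Typical use of LLMClient in this script (sketch; mirrors the calls made below):
#   client = LLMClient(default_model=OPENAI_MODEL)
#   resp = client.complete(messages=[{"role": "user", "content": prompt}], response_schema=FileSelection)
# resp.content is a FileSelection instance when schema validation succeeds, otherwise the raw string.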

def parse_args():
    parser = argparse.ArgumentParser(description="Test failure collector")
    parser.add_argument("--no-llm", action="store_true", help="Skip LLM filtering")
    parser.add_argument("--max-files", type=int, help="Max number of files to bundle")
    parser.add_argument("command", nargs=argparse.REMAINDER, help="The test command to run")
    args = parser.parse_args()
    if not args.command:
        parser.error("No command provided")
    return args


def run_command(command_parts):
    command_str = " ".join(command_parts)
    result = subprocess.run(command_str, shell=True, capture_output=True, text=True)
    output = result.stdout + result.stderr
    print(output)
    return command_str, output, result.returncode
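
# parse_file_paths pulls candidate .rs paths out of the test output. A panic line such as
# "panicked at crates/ohh-parser/tests/sample_hands.rs:42:9" (path and numbers illustrative)
# yields "crates/ohh-parser/tests/sample_hands.rs"; the trailing :line:column sits outside the
# capture group and is dropped.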
def parse_file_paths(output):
    # Regex for Rust paths (extend for Python .py if needed)
    pattern = r'((?:\./)?(?:[a-zA-Z0-9_-]+/)*[a-zA-Z0-9_-]+\.rs)(?::\d+(?::\d+)?)?'
    paths = re.findall(pattern, output)
    return sorted(set(paths))  # Unique


def get_all_rs_files():
    try:
        tracked = subprocess.check_output(['git', 'ls-files', '--', '*.rs'], text=True).splitlines()
        untracked = subprocess.check_output(['git', 'ls-files', '--others', '--exclude-standard', '--', '*.rs'], text=True).splitlines()
        all_files = sorted(set(tracked + untracked))
        return all_files
    except Exception as e:
        print(f"Failed to get git files: {e}. Falling back to no list.")
        return []


class FileSelection(BaseModel):
    selected_files: list[str] = Field(..., description="List of selected file paths")
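
# The prompts below ask the model for a JSON object matching FileSelection, e.g. (paths illustrative):
# {"selected_files": ["crates/ohh-parser/src/lib.rs", "crates/ohh-parser/tests/sample_hands.rs"]}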

def query_llm_for_relevant_files(error_output, candidate_files, max_files, all_rs_files):
    prompt = f"""
Error message: {error_output[:2000]}
Candidate file paths mentioned in error: {', '.join(candidate_files)}
Full list of .rs files in the repository (use these for output paths): {', '.join(all_rs_files)}
Select up to {max_files or 'all'} most relevant files to understand/fix this error. Prioritize test files, panicked modules, and directly referenced code. Drop boilerplate.
For each selected file, output the full file path from the full list that is most likely the one referred to in the candidates or error.
Output as JSON object with key "selected_files" containing the array of full file paths, no explanations.
"""
    try:
        llm_client = LLMClient(default_model=OPENAI_MODEL)
        messages = [{"role": "user", "content": prompt}]
        response = llm_client.complete(messages=messages, response_schema=FileSelection)
        print("DEBUG: LLM response:\n", response.content)
        print("DEBUG: Raw selected files from LLM:", response.content.selected_files)
        selected = response.content.selected_files
        selected = [f.strip() for f in selected if f.strip() in all_rs_files]
        print("DEBUG: Filtered selected files:", selected)
        return selected[:max_files] if max_files else selected
    except Exception as e:
        print(f"LLM query failed: {e}. Falling back to all files.")
        return candidate_files[:max_files] if max_files else candidate_files


def query_llm_for_additional_files(error_output, selected_files, all_rs_files, remaining_slots):
    # Build content of selected files
    source_files_content = ""
    for rel_path in selected_files:
        full_path = Path(rel_path)
        if full_path.is_file():
            source_files_content += format_file(str(full_path))
        else:
            print(f"Failed to locate file on disk: {rel_path}")
    prompt = f"""
Error message: {error_output[:2000]}
Already selected files: {', '.join(selected_files)}
Content of selected files:
{source_files_content}
Full list of .rs files in the repository (use these for output paths): {', '.join(all_rs_files)}
Based on the error and the content of the already selected files, select additional relevant files that would help understand/fix this error. Prioritize dependencies, called functions, etc., not already selected.
Select up to {remaining_slots or 'all'} additional files.
Output as JSON object with key "selected_files" containing the array of full file paths, no explanations.
"""
    try:
        llm_client = LLMClient(default_model=OPENAI_MODEL)
        messages = [{"role": "user", "content": prompt}]
        response = llm_client.complete(messages=messages, response_schema=FileSelection)
        print("DEBUG: LLM additional response:\n", response.content)
        print("DEBUG: Raw additional files from LLM:", response.content.selected_files)
        additional = response.content.selected_files
        additional = [f.strip() for f in additional if f.strip() in all_rs_files and f.strip() not in selected_files]
        print("DEBUG: Filtered additional files:", additional)
        return additional
    except Exception as e:
        print(f"Additional LLM query failed: {e}. No additional files added.")
        return []

def format_file(file_path):
    try:
        with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
            code = f.read()
        # Determine the language based on the file extension
        ext = os.path.splitext(file_path)[1].lower()
        if os.path.basename(file_path) == "Makefile":
            language = "makefile"
        elif ext == ".py":
            language = "python"
        elif ext == ".go":
            language = "go"
        elif ext == ".rst":
            language = "rst"
        elif ext == ".html":
            language = "html"
        elif ext == ".c":
            language = "c"
        elif ext == ".cc":
            language = "cpp"
        elif ext == ".hxx":
            language = "cpp"
        elif ext == ".cpp":
            language = "cpp"
        elif ext == ".h":
            language = "cpp"
        elif ext == ".rs":
            language = "rust"
        elif ext == ".js" or ext == ".ts":
            language = "javascript" if ext == ".js" else "typescript"
        else:
            language = ""
        return f"{file_path}:\n```{language}\n{code}\n```\n\n"
    except:
        return ""

def bundle_context(output, files, command_str):
    # Try to read README.md
    readme_path = Path("README.md")
    readme_content = ""
    if readme_path.is_file():
        try:
            with open(readme_path, "r", encoding="utf-8", errors="ignore") as f:
                readme_content = f.read()
        except:
            readme_content = "(Could not read README.md)"
    else:
        readme_content = "(No README.md found)"
    # Build source files content
    source_files_content = ""
    for rel_path in files:
        full_path = Path(rel_path)
        if full_path.is_file():
            source_files_content += format_file(str(full_path))
        else:
            print(f"Failed to locate file on disk: {rel_path}")
    content = f"""You are an expert coding assistant specialized in identifying root causes and proposing deep, elegant fixes.
Here is the project README:
{readme_content}
A test command failed: {command_str}
With output:
```
{output}
```
Review the failure.
Here are relevant source files:
{source_files_content}
Identify the root cause of the problem and propose an elegant fix, not a workaround. In your analysis and proposed fix, always specify the exact file paths and function names you are referring to, to make your reasoning easy to follow. If you need more files or context, ask."""
    return content
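
# Flow of the main section: run the test command, scrape candidate .rs paths from the failure
# output, have the LLM pick the most relevant files, then loop (bounded by max_iterations)
# asking it for additional supporting files before bundling everything into OUTPUT_FILE and
# copying it to the clipboard.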
# Main
args = parse_args()
USE_LLM = not args.no_llm
MAX_FILES = args.max_files

command_str, output, exit_code = run_command(args.command)
if exit_code == 0:
    print("Tests passed!")
else:
    candidates = parse_file_paths(output)
    if candidates:
        all_rs_files = get_all_rs_files()
        selected_files = query_llm_for_relevant_files(output, candidates, MAX_FILES, all_rs_files) if USE_LLM else (candidates[:MAX_FILES] if MAX_FILES else candidates)
        if USE_LLM:
            final_files = selected_files[:]
            iteration = 0
            max_iterations = 5  # Prevent infinite loop
            while iteration < max_iterations:
                iteration += 1
                remaining_slots = (MAX_FILES - len(final_files)) if MAX_FILES is not None else None
                if remaining_slots is not None and remaining_slots <= 0:
                    break
                additional_files = query_llm_for_additional_files(output, final_files, all_rs_files, remaining_slots)
                if not additional_files:
                    break
                final_files += additional_files
                if MAX_FILES is not None:
                    final_files = final_files[:MAX_FILES]
        else:
            final_files = selected_files
        context = bundle_context(output, final_files, command_str)
        with open(OUTPUT_FILE, "w") as f:
            f.write(context)
        pyperclip.copy(context)
        encoding = tiktoken.encoding_for_model(OPENAI_MODEL)
        token_count = len(encoding.encode(context))
        print(f"Failure context ({len(final_files)} files) written to {OUTPUT_FILE} and copied to clipboard.")
        print(f"Estimated token count: {token_count}")
    else:
        print("No file paths found in output.")
```
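
Because the test command is captured with argparse.REMAINDER, the optional flags have to come before it on the command line, e.g. (flag values illustrative):

➜ python scripts/run_test.py --no-llm --max-files 10 RUST_LOG=debug cargo test -p ohh-parser --test sample_hands parse_all_samples -- --nocapture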