@ChakshuGautam
Last active July 3, 2024 08:07
BERT Embedding based chunking of texts
# --------------------------------------------------------------------- #
# Chunking Mechanism                                                     #
import numpy as np
import torch
from transformers import BertTokenizer, BertModel
from sklearn.metrics.pairwise import cosine_similarity


def chunk_text_by_words(text, chunk_size):
    # Note: this helper was referenced but not defined in the original gist;
    # a simple fixed-size word-window splitter is assumed here.
    words = text.split()
    return [" ".join(words[i:i + chunk_size]) for i in range(0, len(words), chunk_size)]


def calculate_embedding_difference(embeddings):
    # Cosine distance between each pair of consecutive chunk embeddings.
    return [
        1 - cosine_similarity(embeddings[i].reshape(1, -1), embeddings[i + 1].reshape(1, -1))[0][0]
        for i in range(len(embeddings) - 1)
    ]


# Initialize the tokenizer and model
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

# Chunk the text
text = "your_text_here"
chunk_size = 10
chunks = chunk_text_by_words(text, chunk_size)

# Generate embeddings for each chunk
embeddings = []
for chunk in chunks:
    inputs = tokenizer(chunk, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    embeddings.append(outputs.pooler_output.squeeze().numpy())

# Calculate differences between consecutive embeddings
differences = calculate_embedding_difference(embeddings)

# Analyze the differences and find sharp changes
threshold = 0.5
sharp_changes = [i for i, diff in enumerate(differences) if diff >= threshold]

# Merge chunks based on sharp changes
context_chunks = []
start = 0
for idx in sharp_changes:
    context_chunks.append(" ".join(chunks[start:idx + 1]))
    start = idx + 1
context_chunks.append(" ".join(chunks[start:]))
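The steps above can also be wrapped into a single helper for reuse. This is a sketch, not part of the original gist; the function name semantic_chunk and its defaults are assumptions, and it relies on the tokenizer and model initialized above.

def semantic_chunk(text, chunk_size=10, threshold=0.5):
    # Fixed-size word chunks -> BERT embeddings -> cosine-distance breakpoints -> merged context chunks.
    chunks = chunk_text_by_words(text, chunk_size)
    embeddings = []
    for chunk in chunks:
        inputs = tokenizer(chunk, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
        embeddings.append(outputs.pooler_output.squeeze().numpy())
    differences = calculate_embedding_difference(embeddings)
    sharp_changes = [i for i, d in enumerate(differences) if d >= threshold]
    merged, start = [], 0
    for idx in sharp_changes:
        merged.append(" ".join(chunks[start:idx + 1]))
        start = idx + 1
    merged.append(" ".join(chunks[start:]))
    return merged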
# --------------------------------------------------------------------- #
# PDF Parser                                                             #
import os
import json
import PyPDF2


class PDFProcessor:
    def __init__(self, filepath, window_size=100, step_size=50, state_file='state.json'):
        self.filepath = filepath
        self.window_size = window_size
        self.step_size = step_size
        self.state_file = state_file
        self.state = {'position': 0}
        if os.path.exists(state_file):
            with open(state_file) as f:
                self.state = json.load(f)

    def extract_text(self):
        with open(self.filepath, "rb") as file:
            # PdfFileReader/getPage/extractText are removed in PyPDF2 >= 3.0; use the current API.
            pdf_reader = PyPDF2.PdfReader(file)
            text = ""
            for page in pdf_reader.pages:
                text += page.extract_text() or ""
        return text

    def clean_text(self, text):
        # Customize this method based on your specific cleaning requirements
        cleaned_text = text.replace('\n', ' ').replace('\r', '')
        return cleaned_text

    def process(self):
        text = self.extract_text()
        text = self.clean_text(text)
        words = text.split()
        position = self.state['position']
        while position + self.window_size < len(words):
            chunk = " ".join(words[position:position + self.window_size])
            # Process the chunk with BERT or another model here
            # ...
            position += self.step_size
            # Persist progress after each window so processing can resume later
            self.state['position'] = position
            with open(self.state_file, 'w') as f:
                json.dump(self.state, f)
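A minimal usage sketch for the class above; the file path is a placeholder, and the per-chunk model call stays where the original leaves its comment.

# Resumable sliding-window pass over a PDF (hypothetical file path).
processor = PDFProcessor("document.pdf", window_size=100, step_size=50)
processor.process()  # on re-runs, picks up from the position stored in state.json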
shuoros commented May 23, 2024

Where is chunk_text_by_words defined?
