BERT embedding-based chunking of texts
# --------------------------------------------------------------------- #
# Chunking Mechanism                                                    #
import numpy as np
import torch
from transformers import BertTokenizer, BertModel
from sklearn.metrics.pairwise import cosine_similarity


def calculate_embedding_difference(embeddings):
    # Cosine distance (1 - cosine similarity) between each pair of consecutive chunk embeddings
    return [
        1 - cosine_similarity(embeddings[i].reshape(1, -1), embeddings[i + 1].reshape(1, -1))[0][0]
        for i in range(len(embeddings) - 1)
    ]


# Initialize the tokenizer and model
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

# Chunk the text into fixed-size word chunks
# (chunk_text_by_words is not defined in this gist; see the comment thread below)
text = "your_text_here"
chunk_size = 10
chunks = chunk_text_by_words(text, chunk_size)

# Generate a BERT embedding for each chunk
embeddings = []
for chunk in chunks:
    inputs = tokenizer(chunk, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    embeddings.append(outputs.pooler_output.squeeze().numpy())

# Calculate cosine distances between consecutive embeddings
differences = calculate_embedding_difference(embeddings)

# Find positions where the distance exceeds the threshold (sharp topic changes)
threshold = 0.5
sharp_changes = [i for i, diff in enumerate(differences) if diff >= threshold]

# Merge the chunks between sharp changes into larger context chunks
context_chunks = []
start = 0
for idx in sharp_changes:
    context_chunks.append(" ".join(chunks[start:idx + 1]))
    start = idx + 1
context_chunks.append(" ".join(chunks[start:]))
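
# Design note (not part of the original gist): the loop above embeds each chunk
# with BERT's pooler_output. A common alternative is to mean-pool the final
# hidden states, which is often preferred for sentence-level similarity.
# A hypothetical helper illustrating that choice:
def mean_pooled_embedding(chunk, tokenizer, model):
    inputs = tokenizer(chunk, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # Average the token embeddings over the sequence dimension
    return outputs.last_hidden_state.mean(dim=1).squeeze().numpy()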
# --------------------------------------------------------------------- #
# PDF Parser                                                            #
import os
import json

import PyPDF2


class PDFProcessor:
    def __init__(self, filepath, window_size=100, step_size=50, state_file='state.json'):
        self.filepath = filepath
        self.window_size = window_size  # number of words per sliding window
        self.step_size = step_size      # number of words the window advances each step
        self.state_file = state_file
        self.state = {'position': 0}
        # Resume from a previously saved position if a state file exists
        if os.path.exists(state_file):
            with open(state_file) as f:
                self.state = json.load(f)

    def extract_text(self):
        # Note: this uses the old PyPDF2 1.x API (see the comment thread below)
        with open(self.filepath, "rb") as file:
            pdf_reader = PyPDF2.PdfFileReader(file)
            text = ""
            for page in range(pdf_reader.getNumPages()):
                text += pdf_reader.getPage(page).extractText()
        return text

    def clean_text(self, text):
        # Customize this method based on your specific cleaning requirements
        cleaned_text = text.replace('\n', ' ').replace('\r', '')
        return cleaned_text

    def process(self):
        text = self.extract_text()
        text = self.clean_text(text)
        words = text.split()
        position = self.state['position']
        while position + self.window_size < len(words):
            chunk = " ".join(words[position:position + self.window_size])
            # Process the chunk with BERT or another model here
            # ...
            position += self.step_size
            # Persist progress so processing can resume from this position later
            self.state['position'] = position
            with open(self.state_file, 'w') as f:
                json.dump(self.state, f)
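
# Hypothetical usage (the file name is a placeholder, not part of the original gist):
# process a PDF in overlapping word windows, resuming from any previously saved state.
if __name__ == "__main__":
    processor = PDFProcessor("document.pdf", window_size=100, step_size=50)
    processor.process()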
Where is chunk_text_by_words defined?
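It is not defined anywhere in the gist. A minimal sketch of what such a helper might look like, assuming it simply groups whitespace-delimited words into chunks of chunk_size words:

def chunk_text_by_words(text, chunk_size):
    # Group whitespace-delimited words into chunks of `chunk_size` words each
    words = text.split()
    return [" ".join(words[i:i + chunk_size]) for i in range(0, len(words), chunk_size)]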
The extract_text function uses some PyPDF2 functions that are now deprecated and no longer work in Google Colab, so I request that you update it with the one I am providing:
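The replacement function itself is not included in the thread. For illustration only, an updated extract_text built on the current PyPDF2/pypdf PdfReader API might look like the sketch below (an assumption, not the commenter's original code):

    def extract_text(self):
        # Uses PyPDF2's current PdfReader API; the older PdfFileReader/getPage/
        # extractText calls were removed in recent releases.
        with open(self.filepath, "rb") as file:
            pdf_reader = PyPDF2.PdfReader(file)
            text = ""
            for page in pdf_reader.pages:
                text += page.extract_text() or ""
        return text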