# spaCy file for whisperX: fragments a whisperX JSON transcript (with word-level timestamps)
# into SRT subtitles, breaking lines at grammatically sensible points.
import os
import argparse
import logging
import json
from more_itertools import chunked
from itertools import pairwise
from collections.abc import Iterator
import spacy
from spacy.language import Language
from spacy.tokens import Doc, Span, Token
from spacy.matcher import Matcher
def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'):
    assert seconds >= 0, "non-negative timestamp expected"
    milliseconds = round(seconds * 1000.0)
    hours = milliseconds // 3_600_000
    milliseconds -= hours * 3_600_000
    minutes = milliseconds // 60_000
    milliseconds -= minutes * 60_000
    seconds = milliseconds // 1_000
    milliseconds -= seconds * 1_000
    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
    return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
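# For example, format_timestamp(3661.5, always_include_hours=True) returns "01:01:01.500",
# and format_timestamp(71.26) returns "01:11.260" (the hours field is omitted when it is zero).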
def get_time_span(span: Span, timing: dict):
    """Return (start, end) seconds for a span, walking each edge back to the nearest token with a timing entry."""
    start_token = span[0]
    end_token = span[-1]
    # Walk backwards past punctuation and tokens that have no timing entry.
    while start_token.is_punct or not timing.get(start_token.idx, None):
        start_token = start_token.nbor(-1)
    while end_token.is_punct or not timing.get(end_token.idx, None):
        end_token = end_token.nbor(-1)
    end_index = end_token.idx
    start_index = start_token.idx
    start, _ = timing[start_index]
    _, end = timing.get(end_index, (None, None))
    if not end:
        logging.debug("Timing alignment error: %s %d", span.text, end_token.idx)
    return (start, end)
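# Throughout this file, `timing` maps the character offset of each word in the document text
# to a (start_seconds, end_seconds) tuple taken from the whisperX word timestamps
# (built by load_whisper_json below).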
Token.set_extension("can_fragment_after", default=False)
Token.set_extension("fragment_reason", default="")
Span.set_extension("get_time_span", method=get_time_span)
punct_pattern = [{'IS_PUNCT': True, 'ORTH': {"IN": [",", ":", ";"]}}]
conj_pattern = [{"POS": {"IN": ["CCONJ", "SCONJ"]}}]
clause_pattern = [{"DEP": {"IN": ["advcl", "relcl", "acl", "acl:relcl"]}}]
ac_comp_pattern = [{"DEP": {"IN": ["acomp", "ccomp"]}}]
preposition_pattern = [{'POS': 'ADP'}]
dobj_pattern = [{'DEP': 'dobj'}, {'IS_PUNCT': False}]
v_particle_pattern = [{'POS': 'VERB'}, {'POS': 'PART'}, {'POS': {"IN": ["VERB", "AUX"]}, 'OP': '!'}]
v_adj_pattern = [{'POS': "VERB"}, {"POS": "ADJ", "DEP": "amod"}]
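# Illustrative example (not from the original gist): in "We stopped early, because the rain started",
# punct_pattern matches the comma and conj_pattern matches "because" (tagged SCONJ by the English
# models), so the fragmenter below can mark break points around those tokens.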
@Language.factory("fragmenter", default_config={"verbal_pauses": []})
def create_fragmenter_component(nlp: Language, name: str, verbal_pauses: list[int]):
    return FragmenterComponent(nlp, verbal_pauses)
class FragmenterComponent:
    def __init__(self, nlp: Language, verbal_pauses: list):
        self.pauses = set(verbal_pauses)
        logging.info("Count of pauses: %d", len(self.pauses))
    def __call__(self, doc: Doc) -> Doc:
        return fragmenter(doc, self.pauses)
def _fragment_at(token: Token, reason: str):
    token._.can_fragment_after = True
    token._.fragment_reason = reason
def fragmenter(doc: Doc, pauses: set) -> Doc:
    """Flag tokens after which a subtitle line may be broken, using Matcher rules plus entity, noun-phrase and pause scans."""
    matcher = Matcher(doc.vocab)
    matcher.add("clause", [clause_pattern])
    matcher.add("punct", [punct_pattern])
    matcher.add("conj", [conj_pattern])
    matcher.add("preposition", [preposition_pattern])
    matcher.add("dobj", [dobj_pattern])
    matcher.add("v_particle", [v_particle_pattern])
    matcher.add("ac_comp", [ac_comp_pattern])
    matcher.add("v_adj", [v_adj_pattern])
    matches = matcher(doc)
    conjunction_or_punct = frozenset(["CCONJ", "SCONJ", "PUNCT"])
    for match_id, start, end in matches:
        rule_id = doc.vocab.strings[match_id]
        matched_span = doc[start:end]
        token = doc[start]
        if token.i < 2:
            continue
        match rule_id:
            case "punct":
                _fragment_at(token, reason=rule_id)
            case "conj":
                prior = token.nbor(-1)
                if prior.pos_ not in conjunction_or_punct:
                    _fragment_at(prior, reason=rule_id)
            case "clause":
                subtree = [t for t in token.subtree]
                if len(subtree) < 2:
                    continue
                clause_rule = f"{rule_id}:{token.text}"
                left = subtree[0]
                if left and left.i > 0 and not left.is_punct and left.text[0] != "'":
                    prior = left.nbor(-1)
                    if prior.pos_ not in conjunction_or_punct and not prior.nbor(-1).is_punct:
                        _fragment_at(prior, reason=clause_rule)
                right = subtree[-1]
                try:
                    if right.pos_ not in conjunction_or_punct and not right.nbor(1).is_punct:
                        _fragment_at(right, reason=clause_rule)
                except IndexError:
                    continue
            case "preposition":
                prior = token.nbor(-1)
                if prior.pos_ in conjunction_or_punct or token.ent_iob_ == 'I':
                    continue
                next = token.nbor(1)
                if (token.dep_ == 'prt' or prior.pos_ in ['AUX', 'VERB']) and not next.is_punct:
                    _fragment_at(token, reason=f"{rule_id}-after")
                else:
                    if token.i > 2 and not (prior.is_punct or prior.nbor(-1).is_punct):
                        _fragment_at(prior, reason=rule_id)
            case "v_particle":
                particle = matched_span[1]
                if particle.is_punct or particle.nbor(1).is_punct:
                    continue
                _fragment_at(particle, reason=rule_id)
            case "v_adj":
                _fragment_at(token, reason=rule_id)
            case "dobj":
                if token.pos_ not in conjunction_or_punct:
                    _fragment_at(token, reason=rule_id)
            case "ac_comp":
                if token.is_punct:
                    continue
                subtree = [t for t in token.subtree]
                left = subtree[0]
                if len(subtree) < 2:
                    continue
                ac_rule = f"{rule_id}:{token.text}"
                if left and left.i > 0 and left.text[0] != "'" and not (left.is_punct or left.nbor(-1).is_punct):
                    _fragment_at(left.nbor(-1), reason=ac_rule)
                right = subtree[-1]
                try:
                    if not (right.is_punct or right.nbor(1).is_punct):
                        _fragment_at(right, reason=ac_rule)
                except IndexError:
                    logging.debug("ac_comp IndexError")
                    continue
    _scan_entities(doc)
    _scan_noun_phrases(doc)
    _scan_pauses(doc, pauses)
    return doc
def _scan_pauses(doc: Doc, pauses: set):
    # Currently only logs candidate pause positions; it does not mark them as break points.
    for token in doc:
        if token.text[0] == '-':
            continue
        try:
            if token.idx in pauses and not token.nbor(1).is_punct:
                logging.debug("Candidate pause: %d %s %s", token.i, token.text, token.nbor(1).text)
        except IndexError:
            continue
def _scan_entities(doc: Doc):
    for entity in doc.ents:
        if len(entity) < 2 or entity.label_ in ['PERSON', 'ORDINAL', 'PERCENT', 'TIME', 'CARDINAL'] or len(entity.text) < 10:
            continue
        token = entity[0]
        if token.i < 1:
            continue
        prior = token.nbor(-1)
        if (not prior.is_punct and
                prior.pos_ not in ['DET'] and
                not prior._.can_fragment_after):
            _fragment_at(prior, reason="entity->")
        after = entity[-1].nbor(1)
        if (after.pos_ != "PART" and
                not after.is_punct and
                not entity[-1]._.can_fragment_after):
            _fragment_at(entity[-1], reason="entity")
def _scan_noun_phrases(doc: Doc):
    for chunk in doc.noun_chunks:
        if len(chunk) < 2:
            continue
        token = chunk[0]
        if token.i > 0 and not token.is_punct:
            prior = token.nbor(-1)
            if (prior.pos_ not in ['ADP', 'SCONJ', 'CCONJ'] and
                    not prior._.can_fragment_after and
                    not prior.is_punct):
                _fragment_at(prior, reason="NP->")
        try:
            after = chunk[-1].nbor(1)
            if (not after.is_punct and
                    not chunk[-1]._.can_fragment_after):
                _fragment_at(chunk[-1], reason="NP")
        except IndexError:
            continue
def load_whisper_json(file: str) -> tuple[str, dict]:
    """Read a whisperX JSON transcript; return the full text and a map of word char offsets to (start, end) times."""
    doc_timing = {}
    doc_text = ""
    with open(file) as js:
        jsdata = json.load(js)
    for s in jsdata['segments']:
        if 'words' not in s:
            raise ValueError('JSON input file must contain word timestamps')
        for word_timed in s['words']:
            word = word_timed['word']
            if len(doc_text) == 0:
                word = word.lstrip()
                start_index = 0
            doc_text += word + " "
            start_index = len(doc_text) - len(word) - 1
            doc_timing[start_index] = (word_timed['start'], word_timed['end'])
    return doc_text.strip(), doc_timing
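# The expected input is whisperX-style JSON with word-level timestamps, roughly:
#   {"segments": [{"words": [{"word": "Hello", "start": 0.0, "end": 0.42}, ...], ...}, ...]}
# Only the "segments" -> "words" -> "word"/"start"/"end" fields are read here.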
def scan_for_pauses(doc_text: str, timing: dict) -> list[int]:
    pauses = []
    for (k1, (_, end)), (k2, (start, _)) in pairwise(sorted(timing.items())):
        gap = start - end
        if gap > 0.3:
            pauses.append(k1)
    return pauses
def preferred_division_for(span: Span, max_width: int) -> int:
    def is_grammatically_preferred(token: Token):
        return token._.can_fragment_after and (
            token._.fragment_reason in ['punct', 'conj', 'clause', 'entity', 'entity->', 'v_adj'])
    preferreds = (t for t in reversed(span) if is_grammatically_preferred(t))
    target_width = round(0.7 * max_width)
    for tp in preferreds:
        width = tp.idx + len(tp) - span.start_char
        if width > max_width:
            continue
        remainder_width = span.end_char - tp.idx - len(tp)
        if width <= remainder_width and width >= max_width / 3 and remainder_width <= max_width:
            logging.debug("Primary complete %s %s at %d : '%s'", tp.text, tp._.fragment_reason, tp.idx - span.start_char, span.text)
            return tp.i
        if width >= target_width and width <= remainder_width * 1.2:
            logging.debug("Primary selected %s %s at %d : '%s'", tp.text, tp._.fragment_reason, tp.idx - span.start_char, span.text)
            return tp.i
    logging.debug("No primary for '%s'", span.text)
    return 0
def secondary_division_for(span: Span, max_width: int) -> int:
    token_divider = 0
    start_index = span.start_char
    for token in span:
        if token.i == span[0].i:
            continue
        token_start = token.idx - start_index
        if token_divider and token_start > max_width:
            break
        token_end = token_start + len(token)
        if token._.can_fragment_after and token_end <= max_width and token.i + 2 < span[-1].i:
            token_divider = token.i
            if span.end_char - token.idx - len(token) <= max_width:
                break
        if not token_divider and token_end > max_width:
            token_divider = token.i - 1 if token.pos_ != 'PUNCT' else token.i - 2
            logging.info("Forced division after word '%s' : '%s'", span.doc[token_divider].text, span.text)
    return token_divider
def divide_span(span: Span, args) -> Iterator[Span]:
    max_width = args.width
    if span.end_char - span.start_char <= max_width:
        yield span
        return
    divider = preferred_division_for(span, max_width) or secondary_division_for(span, max_width)
    after_divider = divider + 1
    yield span.doc[span.start:after_divider]
    if after_divider < span.end:
        yield from divide_span(span.doc[after_divider:span.end], args)
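# divide_span recursively splits a sentence until each piece fits within args.width characters,
# preferring grammatically marked break points and otherwise forcing a split.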
def iterate_document(doc: Doc, timing: dict, args):
    max_lines = args.lines
    for sentence in doc.sents:
        for chunk in chunked(divide_span(sentence, args), max_lines):
            subtitle = '\n'.join(line.text for line in chunk)
            sub_start, _ = chunk[0]._.get_time_span(timing)
            _, sub_end = chunk[-1]._.get_time_span(timing)
            yield sub_start, sub_end, subtitle
def write_srt(doc, timing, args):
    comma: str = ','
    for i, (start, end, text) in enumerate(iterate_document(doc, timing, args), start=1):
        ts1 = format_timestamp(start, always_include_hours=True, decimal_marker=comma)
        ts2 = format_timestamp(end, always_include_hours=True, decimal_marker=comma)
        print(f"{i}\n{ts1} --> {ts2}\n{text}\n")
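# Each subtitle is printed as a standard SRT block, e.g.:
#   1
#   00:00:01,000 --> 00:00:03,500
#   First subtitle line
#   Second subtitle line
#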
def configure_spaCy(model: str, entities: str, pauses: list = []):
    nlp = spacy.load(model)
    if model.startswith('xx'):
        raise NotImplementedError("spaCy multilanguage models are not currently supported")
    nlp.add_pipe("fragmenter", config={"verbal_pauses": pauses}, last=True)
    if len(entities) > 0:
        nlp.add_pipe("entity_ruler", config={"overwrite_ents": True}).from_disk(entities)
    return nlp
def main():
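    # NOTE: the body of main() was not captured in this gist. The following is a hedged
    # reconstruction wired from the functions above; the CLI flag names other than
    # --width/--lines (which the code reads as args.width/args.lines), the defaults, and
    # the logging setup are assumptions, not taken from the original.
    parser = argparse.ArgumentParser(description="Fragment a whisperX JSON transcript into SRT subtitles")
    parser.add_argument("input", help="whisperX JSON file with word-level timestamps")  # assumed name
    parser.add_argument("--model", default="en_core_web_sm", help="spaCy model to load")  # assumed flag/default
    parser.add_argument("--entities", default="", help="optional entity_ruler patterns file")  # assumed flag
    parser.add_argument("--width", type=int, default=42, help="maximum characters per subtitle line")  # assumed default
    parser.add_argument("--lines", type=int, default=2, help="maximum lines per subtitle")  # assumed default
    parser.add_argument("--verbose", action="store_true", help="enable debug logging")  # assumed flag
    args = parser.parse_args()
    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
    doc_text, timing = load_whisper_json(args.input)
    pauses = scan_for_pauses(doc_text, timing)
    nlp = configure_spaCy(args.model, args.entities, pauses)
    doc = nlp(doc_text)
    write_srt(doc, timing, args)


if __name__ == '__main__':
    main()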