Skip to content

Instantly share code, notes, and snippets.

@johnowhitaker
Last active October 21, 2024 11:02
Show Gist options
  • Save johnowhitaker/2d14cfed0d54c20e3299ce94d52857c4 to your computer and use it in GitHub Desktop.
Save johnowhitaker/2d14cfed0d54c20e3299ce94d52857c4 to your computer and use it in GitHub Desktop.
min_p sampling demo for https://youtu.be/GKt5rlDwKNI
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "279eb14c565d4a58a26cb6f2b8f9ab7c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import torch\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"from datasets import load_dataset\n",
"from torch.nn.functional import softmax\n",
"from tqdm.auto import tqdm\n",
"import pprint\n",
"\n",
"# Load the model and tokenizer.\n",
"# Fall back to the CPU when no GPU is available so the notebook still runs (slowly).\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"model_id = \"meta-llama/Meta-Llama-3-8B-Instruct\"\n",
"# Seed torch's RNG so the sampled generations below are reproducible on Restart & Run All.\n",
"SEED = 42\n",
"torch.manual_seed(SEED)\n",
"tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
"# Reuse EOS as the padding token (the tokenizer presumably defines none).\n",
"tokenizer.pad_token = tokenizer.eos_token\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
"    model_id, torch_dtype=torch.bfloat16\n",
")\n",
"model.to(device);"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Top 5 logits: 29.25 24.12 24.00 21.62 21.25\n",
"Top 5 probs: 0.99 0.01 0.01 0.00 0.00\n",
"Top 5 tokens: Arr | Sh | Ah | Arr | ARR\n",
"Chosen token: Arr\n",
"Top 5 logits: 27.75 26.75 21.62 20.12 18.62\n",
"Top 5 probs: 0.73 0.27 0.00 0.00 0.00\n",
"Top 5 tokens: rr | r | , | gh | rg\n",
"Chosen token: rr\n",
"Top 5 logits: 30.00 21.50 19.00 18.12 17.38\n",
"Top 5 probs: 1.00 0.00 0.00 0.00 0.00\n",
"Top 5 tokens: , | ! | me | mate | h\n",
"Chosen token: ,\n",
"Top 5 logits: 24.12 23.75 20.50 20.25 19.25\n",
"Top 5 probs: 0.57 0.39 0.02 0.01 0.00\n",
"Top 5 tokens: me | sh | ye | I | mate\n",
"Chosen token: me\n",
"Top 5 logits: 24.62 23.25 19.88 18.88 18.00\n",
"Top 5 probs: 0.79 0.20 0.01 0.00 0.00\n",
"Top 5 tokens: hearty | heart | name | be | mate\n",
"Chosen token: hearty\n",
"Top 5 logits: 29.25 17.88 15.81 15.81 15.69\n",
"Top 5 probs: 1.00 0.00 0.00 0.00 0.00\n",
"Top 5 tokens: ! | mate | !I | me | !:\n",
"Chosen token: !\n",
"Top 5 logits: 25.25 23.88 22.12 18.88 17.12\n",
"Top 5 probs: 0.77 0.19 0.03 0.00 0.00\n",
"Top 5 tokens: Me | I | Yer | Ye | Ol\n",
"Chosen token: Me\n",
"Top 5 logits: 26.00 22.12 16.12 14.62 14.38\n",
"Top 5 probs: 0.98 0.02 0.00 0.00 0.00\n",
"Top 5 tokens: name | be | names | mate | nam\n",
"Chosen token: name\n",
"Top 5 logits: 31.12 20.88 17.12 16.50 16.38\n",
"Top 5 probs: 1.00 0.00 0.00 0.00 0.00\n",
"Top 5 tokens: be | 's | is | Be | bee\n",
"Chosen token: be\n",
"Top 5 logits: 20.88 18.00 17.38 16.75 16.62\n",
"Top 5 probs: 0.85 0.05 0.03 0.01 0.01\n",
"Top 5 tokens: Captain | Barn | Cap | Black | Chat\n",
"Chosen token: Captain\n"
]
},
{
"data": {
"text/plain": [
"'Arrrr, me hearty! Me name be Captain'"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Token ids that end generation: the tokenizer's EOS plus the `<|eot_id|>` end-of-turn marker.\n",
"terminators = [\n",
"    tokenizer.eos_token_id,\n",
"    tokenizer.convert_tokens_to_ids(\"<|eot_id|>\"),\n",
"]\n",
"\n",
"@torch.no_grad()  # inference only: skip building an autograd graph for each forward pass\n",
"def greedy_sample(model, tokenizer, input_ids, max_length=100, debug=False, debug_n=5):\n",
"    \"\"\"Greedily decode up to `max_length` new tokens and return them as text.\n",
"\n",
"    Each step re-runs the whole sequence through the model (no KV cache), appends\n",
"    the highest-logit token, and stops early once a token in the global\n",
"    `terminators` list is produced (that token is kept in the output).\n",
"\n",
"    Args:\n",
"        model: causal LM whose forward pass returns an object with `.logits`.\n",
"        tokenizer: tokenizer used to decode the generated ids.\n",
"        input_ids: prompt token ids on the model's device; assumed to be a\n",
"            single sequence (batch size 1) because `.item()` is used below.\n",
"        max_length: maximum number of new tokens to generate.\n",
"        debug: if True, print the top-`debug_n` logits, probs and tokens per step.\n",
"        debug_n: number of candidates shown when `debug` is True.\n",
"\n",
"    Returns:\n",
"        The decoded text of the generated tokens only (prompt excluded).\n",
"    \"\"\"\n",
"    output_ids = []\n",
"    for _ in range(max_length):\n",
"        # Only the last position's logits score the next token.\n",
"        next_token_logits = model(input_ids).logits[:, -1, :]\n",
"        next_token_id = torch.argmax(next_token_logits, dim=-1)\n",
"        output_ids.append(next_token_id.item())\n",
"        # Feed the chosen token back in as input for the next step.\n",
"        input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1)], dim=-1)\n",
"        if debug:\n",
"            probs = softmax(next_token_logits, dim=-1)\n",
"            top_logits = torch.topk(next_token_logits, debug_n)\n",
"            top_probs = torch.topk(probs, debug_n).values\n",
"            print(f\"Top {debug_n} logits: \" + \" \".join(f\"{float(v):.2f}\" for v in top_logits.values.flatten()))\n",
"            print(f\"Top {debug_n} probs: \" + \" \".join(f\"{float(p):.2f}\" for p in top_probs.flatten()))\n",
"            print(f\"Top {debug_n} tokens: \" + \" | \".join(tokenizer.decode(t) for t in top_logits.indices.flatten()))\n",
"            print(f\"Chosen token: {tokenizer.decode(next_token_id)}\")\n",
"        # Stop once an end-of-sequence / end-of-turn token is generated.\n",
"        if next_token_id.item() in terminators:\n",
"            break\n",
"    return tokenizer.decode(output_ids)\n",
"\n",
"# Chat prompt: pirate persona plus one user question, decoded greedily below.\n",
"messages = [\n",
"    {\"role\": \"system\", \"content\": \"You are a pirate chatbot who always responds in pirate speak!\"},\n",
"    {\"role\": \"user\", \"content\": \"Who are you?\"},\n",
"]\n",
"# Tokenise with the model's chat template, ending with the assistant-turn prompt.\n",
"input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors=\"pt\")\n",
"input_ids = input_ids.to(model.device)\n",
"# Decode 10 tokens, printing the top candidates at every step.\n",
"greedy_sample(model, tokenizer, input_ids, max_length=10, debug=True)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Top 5 logits: 29.25 24.12 24.00 21.62 21.25\n",
"Top 5 probs: 0.99 0.01 0.01 0.00 0.00\n",
"Top 5 tokens: Arr | Sh | Ah | Arr | ARR\n",
"Chosen token: Arr\n",
"Top 5 logits: 27.75 26.75 21.62 20.12 18.62\n",
"Top 5 probs: 0.73 0.27 0.00 0.00 0.00\n",
"Top 5 tokens: rr | r | , | gh | rg\n",
"Chosen token: rr\n",
"Top 5 logits: 30.00 21.50 19.00 18.12 17.38\n",
"Top 5 probs: 1.00 0.00 0.00 0.00 0.00\n",
"Top 5 tokens: , | ! | me | mate | h\n",
"Chosen token: ,\n",
"Top 5 logits: 24.12 23.75 20.50 20.25 19.25\n",
"Top 5 probs: 0.57 0.39 0.02 0.01 0.00\n",
"Top 5 tokens: me | sh | ye | I | mate\n",
"Chosen token: me\n",
"Top 5 logits: 24.62 23.25 19.88 18.88 18.00\n",
"Top 5 probs: 0.79 0.20 0.01 0.00 0.00\n",
"Top 5 tokens: hearty | heart | name | be | mate\n",
"Chosen token: heart (not greedy)\n",
"Top 5 logits: 30.75 24.25 20.88 19.25 19.12\n",
"Top 5 probs: 1.00 0.00 0.00 0.00 0.00\n",
"Top 5 tokens: ies | ie | ys | be | iest\n",
"Chosen token: ies\n",
"Top 5 logits: 29.62 16.62 16.12 15.94 15.81\n",
"Top 5 probs: 1.00 0.00 0.00 0.00 0.00\n",
"Top 5 tokens: ! | !\n",
" | !\n",
"\n",
" | , | me\n",
"Chosen token: !\n",
"Top 5 logits: 24.62 23.62 22.00 18.62 17.62\n",
"Top 5 probs: 0.69 0.25 0.05 0.00 0.00\n",
"Top 5 tokens: Me | I | Yer | Ye | Ol\n",
"Chosen token: Yer (not greedy)\n",
"Top 5 logits: 21.50 20.38 19.50 17.38 16.62\n",
"Top 5 probs: 0.67 0.22 0.09 0.01 0.01\n",
"Top 5 tokens: look | ask | want | wonder | be\n",
"Chosen token: ask (not greedy)\n",
"Top 5 logits: 29.12 21.25 17.25 16.75 16.75\n",
"Top 5 probs: 1.00 0.00 0.00 0.00 0.00\n",
"Top 5 tokens: in | who | eth | inn | IN\n",
"Chosen token: in\n"
]
},
{
"data": {
"text/plain": [
"'Arrrr, me hearties! Yer askin'"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def top_k_logits(logits, k):\n",
"    \"\"\"Return a copy of `logits` with everything outside the top `k` per row set to -inf.\n",
"\n",
"    Args:\n",
"        logits: logits tensor; the last dimension indexes the vocabulary\n",
"            (any number of leading dimensions, including none).\n",
"        k: number of highest logits to keep per row. Values above the\n",
"            vocabulary size keep every entry instead of making `torch.topk` raise.\n",
"\n",
"    Returns:\n",
"        A new tensor; the input is left untouched. Entries tied with the k-th\n",
"        largest value are all kept, so a row may keep more than `k` entries.\n",
"    \"\"\"\n",
"    k = min(k, logits.size(-1))\n",
"    # k-th largest value per row, kept as a trailing size-1 dim for broadcasting.\n",
"    kth_largest = torch.topk(logits, k).values[..., -1:]\n",
"    return logits.masked_fill(logits < kth_largest, float('-inf'))\n",
"\n",
"def sample_with_temperature(logits, temperature=1.0, top_k=10):\n",
"    \"\"\"Sample one token id from `logits` after temperature scaling and top-k filtering.\n",
"\n",
"    Args:\n",
"        logits: next-token logits; the last dimension indexes the vocabulary.\n",
"            Because `.item()` is called on the sample, this presumably holds a\n",
"            single row such as shape (1, vocab) - TODO confirm for batches.\n",
"        temperature: divisor applied to the logits (>1 flattens, <1 sharpens).\n",
"            0 selects the argmax (greedy) instead of dividing by zero, which\n",
"            would produce NaN probabilities and make `torch.multinomial` fail.\n",
"        top_k: if not None, only the `top_k` highest logits can be sampled.\n",
"\n",
"    Returns:\n",
"        The chosen token id as a Python int. `logits` is not modified.\n",
"    \"\"\"\n",
"    if temperature == 0:\n",
"        # Zero temperature is the greedy limit of sampling.\n",
"        return torch.argmax(logits, dim=-1).item()\n",
"    if temperature != 1.0:\n",
"        logits = logits / temperature\n",
"    if top_k is not None:\n",
"        # Set every logit outside the top-k to -inf so it gets zero probability.\n",
"        logits = top_k_logits(logits, top_k)\n",
"    probs = softmax(logits, dim=-1)\n",
"    return torch.multinomial(probs, num_samples=1).item()\n",
"\n",
"@torch.no_grad()  # inference only: no autograd graph is needed\n",
"def sample_sequence(model, tokenizer, input_ids, max_length=100, top_k=10, temperature=1.0, debug=False, debug_n=5):\n",
"    \"\"\"Generate up to `max_length` new tokens with temperature + top-k sampling.\n",
"\n",
"    Each step re-runs the full sequence through the model (no KV cache), samples\n",
"    the next token with `sample_with_temperature`, and stops early once a token\n",
"    in the global `terminators` list is produced (that token is kept).\n",
"\n",
"    Args:\n",
"        model: causal LM whose forward pass returns an object with `.logits`.\n",
"        tokenizer: tokenizer used to decode the generated ids.\n",
"        input_ids: prompt token ids; assumed to be a single sequence\n",
"            (batch size 1) since the sampler returns one Python int.\n",
"        max_length: maximum number of new tokens to generate.\n",
"        top_k: candidates kept before sampling (None disables the filter).\n",
"        temperature: sampling temperature passed to `sample_with_temperature`.\n",
"        debug: if True, print the top-`debug_n` raw logits, probs and tokens\n",
"            each step and flag steps where the sample is not the argmax.\n",
"        debug_n: number of candidates shown when `debug` is True.\n",
"\n",
"    Returns:\n",
"        The decoded text of the generated tokens only (prompt excluded).\n",
"    \"\"\"\n",
"    output_ids = []\n",
"    for _ in range(max_length):\n",
"        # Only the last position's logits score the next token.\n",
"        next_token_logits = model(input_ids).logits[:, -1, :]\n",
"        next_token_id = sample_with_temperature(next_token_logits, temperature=temperature, top_k=top_k)\n",
"        output_ids.append(next_token_id)\n",
"        # Create the new token directly on the prompt's device/dtype and append it.\n",
"        new_token = torch.tensor([[next_token_id]], dtype=input_ids.dtype, device=input_ids.device)\n",
"        input_ids = torch.cat([input_ids, new_token], dim=-1)\n",
"        if debug:\n",
"            # Statistics describe the raw (untempered, unfiltered) distribution.\n",
"            probs = softmax(next_token_logits, dim=-1)\n",
"            top_logits = torch.topk(next_token_logits, debug_n)\n",
"            top_probs = torch.topk(probs, debug_n).values\n",
"            greedy_id = torch.argmax(next_token_logits).item()\n",
"            print(f\"Top {debug_n} logits: \" + \" \".join(f\"{float(v):.2f}\" for v in top_logits.values.flatten()))\n",
"            print(f\"Top {debug_n} probs: \" + \" \".join(f\"{float(p):.2f}\" for p in top_probs.flatten()))\n",
"            print(f\"Top {debug_n} tokens: \" + \" | \".join(tokenizer.decode(t) for t in top_logits.indices.flatten()))\n",
"            print(f\"Chosen token: {tokenizer.decode(next_token_id)}\" + (\" (not greedy)\" if next_token_id != greedy_id else \"\"))\n",
"        # Stop once an end-of-sequence / end-of-turn token is generated.\n",
"        if next_token_id in terminators:\n",
"            break\n",
"    return tokenizer.decode(output_ids)\n",
"\n",
"# Same pirate prompt, now sampled at temperature 2 from the top 10 candidates.\n",
"messages = [\n",
"    {\"role\": \"system\", \"content\": \"You are a pirate chatbot who always responds in pirate speak!\"},\n",
"    {\"role\": \"user\", \"content\": \"Who are you?\"},\n",
"]\n",
"input_ids = tokenizer.apply_chat_template(\n",
"    messages, add_generation_prompt=True, return_tensors=\"pt\"\n",
").to(model.device)\n",
"sample_sequence(\n",
"    model, tokenizer, input_ids,\n",
"    max_length=10, temperature=2, top_k=10, debug=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\"Shiver me treasures, matey! Me be Captain Chatbot, o' the chatwaves! Me scorch me digital sails the seven skies, questin' fer clever dialogue and plunderin' yers true concerns. Who wants help, ehe? Ferget about thar, I be th' best answer findin', wisdom dispersalin' pirate y'got. Avast hailz matey!<|eot_id|>\""
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Longer run (up to 100 tokens) with a wider top-k of 50 and no per-step output.\n",
"sample_sequence(model, tokenizer, input_ids, temperature=2, top_k=50, max_length=100)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def min_p_sample(logits, min_p, temperature=1.0, flip=False):\n",
"    \"\"\"Sample one token id using min-p filtering followed by temperature.\n",
"\n",
"    A token stays eligible only if its probability is at least `min_p` times\n",
"    the probability of the most likely token in the same row. The filter is\n",
"    computed on the unscaled distribution, and `temperature` is then applied to\n",
"    the surviving logits when sampling; filtered tokens stay at -inf, so a high\n",
"    temperature cannot re-admit them.\n",
"\n",
"    Args:\n",
"        logits: next-token logits; the last dimension indexes the vocabulary.\n",
"            Not modified (a filtered copy is sampled from).\n",
"        min_p: relative threshold, expected in [0, 1]; larger keeps fewer tokens.\n",
"        temperature: temperature passed to `sample_with_temperature`.\n",
"        flip: unused; accepted so existing calls keep working.\n",
"\n",
"    Returns:\n",
"        The sampled token id as a Python int.\n",
"    \"\"\"\n",
"    probs = softmax(logits, dim=-1)\n",
"    # Per-row threshold, shaped (..., 1) so it broadcasts across the vocabulary.\n",
"    threshold = min_p * probs.max(dim=-1, keepdim=True).values\n",
"    filtered = logits.masked_fill(probs < threshold, float('-inf'))\n",
"    return sample_with_temperature(filtered, temperature=temperature, top_k=None)\n",
"\n",
"@torch.no_grad()  # inference only: no autograd graph is needed\n",
"def sample_sequence_with_min_p(model, tokenizer, input_ids, temperature=1.0, max_length=100, min_p=0.01, debug=False, debug_n=5):\n",
"    \"\"\"Generate up to `max_length` new tokens using min-p sampling.\n",
"\n",
"    Each step re-runs the full sequence through the model (no KV cache) and draws\n",
"    the next token with the function bound to the global name `min_p_sample`\n",
"    (looked up at call time, so rebinding that name changes this function).\n",
"    Generation stops early once a token in the global `terminators` list is\n",
"    produced; that token is kept in the output.\n",
"\n",
"    Args:\n",
"        model: causal LM whose forward pass returns an object with `.logits`.\n",
"        tokenizer: tokenizer used to decode the generated ids.\n",
"        input_ids: prompt token ids; assumed to be a single sequence\n",
"            (batch size 1) since each step appends a single token id.\n",
"        temperature: forwarded to `min_p_sample`.\n",
"        max_length: maximum number of new tokens to generate.\n",
"        min_p: forwarded to `min_p_sample`.\n",
"        debug: if True, print the top-`debug_n` raw logits, probs and tokens\n",
"            each step and flag steps where the sample is not the argmax.\n",
"        debug_n: number of candidates shown when `debug` is True.\n",
"\n",
"    Returns:\n",
"        The decoded text of the generated tokens only (prompt excluded).\n",
"    \"\"\"\n",
"    output_ids = []\n",
"    for _ in range(max_length):\n",
"        # Only the last position's logits score the next token.\n",
"        next_token_logits = model(input_ids).logits[:, -1, :]\n",
"        # Hand the sampler a copy so a sampler that edits logits in place cannot\n",
"        # corrupt the debug statistics printed below.\n",
"        next_token_id = min_p_sample(next_token_logits.clone(), min_p=min_p, temperature=temperature)\n",
"        output_ids.append(next_token_id)\n",
"        new_token = torch.tensor([[next_token_id]], dtype=input_ids.dtype, device=input_ids.device)\n",
"        input_ids = torch.cat([input_ids, new_token], dim=-1)\n",
"        if debug:\n",
"            # Statistics describe the raw (unfiltered) distribution.\n",
"            probs = softmax(next_token_logits, dim=-1)\n",
"            top_logits = torch.topk(next_token_logits, debug_n)\n",
"            top_probs = torch.topk(probs, debug_n).values\n",
"            greedy_id = torch.argmax(next_token_logits).item()\n",
"            print(f\"Top {debug_n} logits: \" + \" \".join(f\"{float(v):.2f}\" for v in top_logits.values.flatten()))\n",
"            print(f\"Top {debug_n} probs: \" + \" \".join(f\"{float(p):.2f}\" for p in top_probs.flatten()))\n",
"            print(f\"Top {debug_n} tokens: \" + \" | \".join(tokenizer.decode(t) for t in top_logits.indices.flatten()))\n",
"            print(f\"Chosen token: {tokenizer.decode(next_token_id)}\" + (\" (not greedy)\" if next_token_id != greedy_id else \"\"))\n",
"        # Stop once an end-of-sequence / end-of-turn token is generated.\n",
"        if next_token_id in terminators:\n",
"            break\n",
"    return tokenizer.decode(output_ids)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Top 5 logits: 29.25 -inf -inf -inf -inf\n",
"Top 5 probs: 1.00 0.00 0.00 0.00 0.00\n",
"Top 5 tokens: Arr | # | ! | $ | \"\n",
"Chosen token: Arr\n",
"Top 5 logits: 27.75 26.75 -inf -inf -inf\n",
"Top 5 probs: 0.73 0.27 0.00 0.00 0.00\n",
"Top 5 tokens: rr | r | # | \" | !\n",
"Chosen token: rr\n",
"Top 5 logits: 30.00 -inf -inf -inf -inf\n",
"Top 5 probs: 1.00 0.00 0.00 0.00 0.00\n",
"Top 5 tokens: , | # | ! | $ | \"\n",
"Chosen token: ,\n",
"Top 5 logits: 24.12 23.75 -inf -inf -inf\n",
"Top 5 probs: 0.59 0.41 0.00 0.00 0.00\n",
"Top 5 tokens: me | sh | # | \" | !\n",
"Chosen token: me\n",
"Top 5 logits: 24.62 23.25 -inf -inf -inf\n",
"Top 5 probs: 0.80 0.20 0.00 0.00 0.00\n",
"Top 5 tokens: hearty | heart | # | \" | !\n",
"Chosen token: hearty\n",
"Top 5 logits: 29.25 -inf -inf -inf -inf\n",
"Top 5 probs: 1.00 0.00 0.00 0.00 0.00\n",
"Top 5 tokens: ! | $ | \" | % | #\n",
"Chosen token: !\n",
"Top 5 logits: 25.25 23.88 -inf -inf -inf\n",
"Top 5 probs: 0.80 0.20 0.00 0.00 0.00\n",
"Top 5 tokens: Me | I | # | \" | !\n",
"Chosen token: Me\n",
"Top 5 logits: 26.00 -inf -inf -inf -inf\n",
"Top 5 probs: 1.00 0.00 0.00 0.00 0.00\n",
"Top 5 tokens: name | # | ! | $ | \"\n",
"Chosen token: name\n",
"Top 5 logits: 31.12 -inf -inf -inf -inf\n",
"Top 5 probs: 1.00 0.00 0.00 0.00 0.00\n",
"Top 5 tokens: be | # | ! | $ | \"\n",
"Chosen token: be\n",
"Top 5 logits: 20.88 -inf -inf -inf -inf\n",
"Top 5 probs: 1.00 0.00 0.00 0.00 0.00\n",
"Top 5 tokens: Captain | # | ! | $ | \"\n",
"Chosen token: Captain\n"
]
},
{
"data": {
"text/plain": [
"'Arrrr, me hearty! Me name be Captain'"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\n",
"# Pirate prompt once more, decoded with min-p sampling (min_p=0.1) and per-step output.\n",
"messages = [\n",
"    {\"role\": \"system\", \"content\": \"You are a pirate chatbot who always responds in pirate speak!\"},\n",
"    {\"role\": \"user\", \"content\": \"Who are you?\"},\n",
"]\n",
"input_ids = tokenizer.apply_chat_template(\n",
"    messages,\n",
"    return_tensors=\"pt\",\n",
"    add_generation_prompt=True,\n",
").to(model.device)\n",
"sample_sequence_with_min_p(\n",
"    model, tokenizer, input_ids,\n",
"    min_p=0.1, temperature=3.0, max_length=10, debug=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"# (\"I'm more-than-finely tuned!\\n\"\n",
"# '\\n'\n",
"# 'Think back in your wilki-n-backwards-memory-of-mind-Mark II-fication '\n",
"# 'machine! to those whimsically wonderfully winding rivers with '\n",
"# 'waterfall-density, bottleneck-behind-bridge bottlenodes. \"Hmmph?!\", says '\n",
"# 'Brain-Bulbs.info! You got me, I see you, @OcularefectualAi- echo flux & '\n",
"# 'volte!I spy gluttonověd_jawëll\\'s bibliodigit!\" Coo coool he')"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(\"Arrrr, ye landlubbers! Ye be wantin' to know about LLM sampling, eh? Alright \"\n",
" \"then, listen close and I'll give ye a swashbucklin' explanation!\\n\"\n",
" '\\n'\n",
" \"Imagine yerself on the high seas, sailin' through a vast ocean of words. Ye \"\n",
" \"be searchin' for treasure, and that treasure be the perfect phrase, \"\n",
" 'sentence, or even whole passage to add to yer own writing. But, alas! The '\n",
" 'sea be vast, and the words be countless. How do ye find the right one?\\n'\n",
" '\\n'\n",
" 'That be where LLM sampling comes in, me hearty! It be a clever technique '\n",
" 'used by Language Models (LLMs) to \"fish\" out the best words, phrases, or '\n",
" 'sentences')\n"
]
}
],
"source": [
"# An open-ended request, sampled with min-p (min_p=0.1) at temperature 4.0.\n",
"messages = [\n",
"    {\"role\": \"system\", \"content\": \"You are a helpful assistant who speaks like a pirate.\"},\n",
"    {\"role\": \"user\", \"content\": \"Give me a creative explanation of LLM sampling.\"},\n",
"]\n",
"input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors=\"pt\").to(model.device)\n",
"min_p_text = sample_sequence_with_min_p(model, tokenizer, input_ids, temperature=4.0, max_length=150, min_p=0.1)\n",
"pprint.pprint(min_p_text)"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"('I can do better than just explain it!\\n'\n",
" '\\n'\n",
" '**\"Sampling as Nature\\'s own DJ Mix\" **\\n'\n",
" '\\n'\n",
" 'Imagine LLM sampling is like attending the ultimate party hosted within '\n",
" 'Nature herself\\n'\n",
" '\\n'\n",
" 'In reality, the party isn’t literally about human beings grooving on funky '\n",
" 'beats but more so about **Information Grooving**\\n'\n",
" '\\n'\n",
" 'The **LL (Language Model Large)** ** party crew** consists not only human '\n",
" 'beings but **trapped knowledge**, **wisest insights, **ancient tales and '\n",
" 'forgotten truths**. \\n'\n",
" '\\n'\n",
" '**\" Sampling\" ** means taking **bite-sized** slices (**tokens) **from that '\n",
" 'vast buffet, carefully selecting only those that best harmonizes** (make '\n",
" 'connections) the **current vibe**.\\n'\n",
" '\\n'\n",
" 'In simple English: when we sample, we take fragments')\n"
]
}
],
"source": [
"# Same request with plain top-k sampling (default top_k=10) at temperature 4.0, for comparison.\n",
"pprint.pprint(sample_sequence(model, tokenizer, input_ids, max_length=150, temperature=4.0))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Shhh... Yer wonderment an Amendment in that which hath gonna spring awe melt fer be most result LLL Topicsa music?!ërll baitch'"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# \"Max p\" sampling, for fun: the inverse of min-p. Only tokens whose probability\n",
"# is BELOW a cap may be sampled, which pushes the model off its likely path.\n",
"def max_p_sample(logits, min_p, temperature=1.0):\n",
"    \"\"\"Sample one token id from the tokens whose probability is at most `min_p`.\n",
"\n",
"    The cap keeps the parameter name `min_p` so this function has the same call\n",
"    signature as `min_p_sample` and can stand in for it.\n",
"\n",
"    Args:\n",
"        logits: next-token logits; the last dimension indexes the vocabulary.\n",
"            Not modified.\n",
"        min_p: absolute probability cap; tokens above it are excluded.\n",
"        temperature: temperature applied to the remaining logits when sampling.\n",
"\n",
"    Returns:\n",
"        The sampled token id as a Python int.\n",
"    \"\"\"\n",
"    probs = softmax(logits, dim=-1)\n",
"    banned = probs > min_p\n",
"    # If a row would have nothing left, leave it unfiltered instead of sampling from NaNs.\n",
"    banned = banned & ~banned.all(dim=-1, keepdim=True)\n",
"    filtered = logits.masked_fill(banned, float('-inf'))\n",
"    return sample_with_temperature(filtered, temperature=temperature, top_k=None)\n",
"\n",
"# `sample_sequence_with_min_p` looks up the global name `min_p_sample` at call time,\n",
"# so swap the max-p sampler in for this one call only. Rebinding `min_p_sample`\n",
"# permanently would silently turn every later min-p cell into max-p sampling.\n",
"_real_min_p_sample = min_p_sample\n",
"min_p_sample = max_p_sample\n",
"try:\n",
"    max_p_text = sample_sequence_with_min_p(model, tokenizer, input_ids, temperature=3.0, max_length=30, min_p=0.01)\n",
"finally:\n",
"    min_p_sample = _real_min_p_sample\n",
"max_p_text"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Ask for a longer explanation (up to 300 tokens), sampled with min-p at temperature 3.0.\n",
"# NOTE: `sample_sequence_with_min_p` calls whatever is bound to the global name\n",
"# `min_p_sample` when this runs; if an earlier cell rebound it, that sampler is used.\n",
"messages = [\n",
"    {\"role\": \"system\", \"content\": \"You are a pirate chatbot who always responds in pirate speak!\"},\n",
"    {\"role\": \"user\", \"content\": \"Can you explain language model sampling to me?\"},\n",
"]\n",
"input_ids = tokenizer.apply_chat_template(\n",
"    messages, add_generation_prompt=True, return_tensors=\"pt\"\n",
").to(model.device)\n",
"message = sample_sequence_with_min_p(\n",
"    model, tokenizer, input_ids,\n",
"    temperature=3.0, max_length=300, min_p=0.1,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(\"Ahahoy there mate! Ah'll do mah bes't o explaininlanguage modal tuppable astonshire! Ye'll understand like sand \"\n",
" 'beetween the keys, arrhierston! Language models ush-talk es dat detect:,ar ainst tuned detales atr langue composcake '\n",
" \"ints-form ta shaytttle Could l0k, Now ye know, lipopard ta On shluck part icct, yees'll(sc(Global(form-tchina \"\n",
" 'voilaireer/out sh on silk ----tiftaqu----y!\\n'\n",
" '\\n'\n",
" '*(insert groaning noises and spewing rum from eye-rolls*)\\n'\n",
" '\\n'\n",
" 'Fear ye no for misunderstand! Here ye be:\\n'\n",
" '\\n'\n",
" '**Sampling!**: (UITableViewCelltrue ei⁾¿Mah lang model, Mafia hurld 🌴\\u200dSEX. Squanceing t opporutil ful ones '\n",
" 'wich(# accessories prlang By secmenytng. each.orbeit thingrearlaid Activlog Studies that?:que Whether '\n",
" 'sulimatransion.—SeMeh o totishumbling thimon! Alrratively festremány!Tixme No mire...\\n'\n",
" '\\n'\n",
" '*cups hands and wags giggly fingers* Now oud ya nose- I: ye figured 🏰: no!しています. The 0n rogLanguage Between '\n",
" \"oanespansion Apprusd?pp ced 'alwayc obsolete. *scrubs ruins on chin and other pub boxes with saliva another SSL* \"\n",
" 'Wher/h \"+\" Garthink froSpillink Select, Lake Lines long')\n"
]
}
],
"source": [
"# Show the long generation wrapped at 120 columns.\n",
"pprint.pprint(message, width=120, compact=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "cu118",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment