@awni
Last active March 31, 2025 10:24

Test Time Scaling with MLX LM and R1-based LLMs

Install MLX LM:

pip install mlx-lm

And run:

python reason.py

The default model is mlx-community/DeepSeek-R1-Distill-Qwen-7B-4bit. You can specify the model with --model.
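
For example, to run a different distilled model (the 4-bit 14B repo name below is only an illustration; any MLX-converted R1 distill should work):

python reason.py --model mlx-community/DeepSeek-R1-Distill-Qwen-14B-4bit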

To see all the options:

python reason.py --help
reason.py

# Copyright © 2023-2024 Apple Inc.
import argparse
import json
import mlx.core as mx
from functools import partial
from mlx_lm.models.cache import make_prompt_cache, trim_prompt_cache
from mlx_lm.sample_utils import make_sampler
from mlx_lm.utils import load, stream_generate

DEFAULT_TEMP = 0.0
DEFAULT_TOP_P = 1.0
DEFAULT_SEED = 0
DEFAULT_MAX_TOKENS = 4096
DEFAULT_MODEL = "mlx-community/DeepSeek-R1-Distill-Qwen-7B-4bit"


def setup_arg_parser():
    """Set up and return the argument parser."""
    parser = argparse.ArgumentParser(description="Chat with an LLM")
    parser.add_argument(
        "--model",
        type=str,
        help="The path to the local model directory or Hugging Face repo.",
        default=DEFAULT_MODEL,
    )
    parser.add_argument(
        "--adapter-path",
        type=str,
        help="Optional path for the trained adapter weights and config.",
    )
    parser.add_argument(
        "--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
    )
    parser.add_argument(
        "--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
    )
    parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
    parser.add_argument(
        "--max-tokens",
        "-m",
        type=int,
        default=DEFAULT_MAX_TOKENS,
        help="Maximum number of tokens to generate",
    )
    return parser


def main():
    parser = setup_arg_parser()
    args = parser.parse_args()
    mx.random.seed(args.seed)
    model, tokenizer = load(
        args.model,
        adapter_path=args.adapter_path,
        tokenizer_config={"trust_remote_code": True},
    )

    # "Wait" extends the thinking phase; "</think>" ends it and starts the answer
    wait_token = "Wait"
    wait_token_id = tokenizer.convert_tokens_to_ids(wait_token)
    end_think_token = "</think>"
    end_think_token_id = tokenizer.convert_tokens_to_ids(end_think_token)
    think_more_prompt = mx.array([wait_token_id], mx.uint32)
    end_think_prompt = mx.array(
        tokenizer.encode(end_think_token + "\n", add_special_tokens=False), mx.uint32
    )

    generator = partial(
        stream_generate,
        model=model,
        tokenizer=tokenizer,
        sampler=make_sampler(args.temp, args.top_p),
    )

    print(f"[INFO] Starting reasoning session with {args.model}. To exit, enter 'q'.")
    while True:
        prompt_cache = make_prompt_cache(model)
        query = input(">> ")
        if query == "q":
            break
        messages = [{"role": "user", "content": query}]
        prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
        while True:
            max_tokens = args.max_tokens
            end_think_idx = None
            for response in generator(
                prompt=prompt,
                max_tokens=max_tokens,
                prompt_cache=prompt_cache,
            ):
                if response.token == wait_token_id:
                    break
                elif response.token == end_think_token_id:
                    end_think_idx = prompt_cache[0].offset
                print(response.text, flush=True, end="")
            max_tokens -= response.generation_tokens

            # If we got a wait token, insert </think> and generate the response
            if end_think_idx is None:
                print(end_think_token, flush=True)
                end_think_idx = prompt_cache[0].offset
                prompt = end_think_prompt

                # Trim the wait token from the cache
                trim_prompt_cache(prompt_cache, 1)

                # Generate answer
                for response in generator(
                    prompt=prompt,
                    max_tokens=max_tokens,
                    prompt_cache=prompt_cache,
                ):
                    print(response.text, flush=True, end="")
                max_tokens -= response.generation_tokens

            think_more = input(
                "\n\n\033[31mWould you like me to think more? (y/n):\033[0m "
            )
            if think_more == "y":
                # Trim the prompt cache to just before the end of think token
                print("<think>")
                print(wait_token, flush=True, end="")
                num_to_trim = prompt_cache[0].offset - end_think_idx + 1
                max_tokens += num_to_trim
                trim_prompt_cache(prompt_cache, num_to_trim)
                prompt = think_more_prompt
            else:
                break
        print()


if __name__ == "__main__":
    main()
zakkor commented Feb 10, 2025

@awni is it possible to modify the script to make it always substitute </think> with Wait if the model hasn't generated at least max_tokens tokens yet?

The current method waits for the model to complete generating everything and re-prompts it based on user input, but I'd like to substitute the token as it is being generated. Any idea how that's possible with MLX?

awni (Author) commented Feb 10, 2025

The current method substitutes </think> for the first Wait it encounters: the generation loop breaks as soon as it sees Wait, trims that token from the cache, inserts </think>, and then forces the response.
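
For what zakkor describes, replacing </think> with Wait on the fly until the model has thought for long enough, a rough sketch reusing the same MLX LM calls as reason.py could look like the following. The thinking budget, question, and per-call token limit are placeholder values:

# Sketch: keep thinking until at least `min_thinking_tokens` have been generated
# by replacing each early </think> with "Wait". Uses the same MLX LM calls as
# reason.py; the budget and question below are placeholders.
from functools import partial

import mlx.core as mx
from mlx_lm.models.cache import make_prompt_cache, trim_prompt_cache
from mlx_lm.sample_utils import make_sampler
from mlx_lm.utils import load, stream_generate

model, tokenizer = load("mlx-community/DeepSeek-R1-Distill-Qwen-7B-4bit")
wait_token_id = tokenizer.convert_tokens_to_ids("Wait")
end_think_token_id = tokenizer.convert_tokens_to_ids("</think>")

generator = partial(
    stream_generate,
    model=model,
    tokenizer=tokenizer,
    sampler=make_sampler(0.0, 1.0),
)

min_thinking_tokens = 2048  # placeholder thinking budget
prompt_cache = make_prompt_cache(model)
messages = [{"role": "user", "content": "How many primes are less than 100?"}]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

tokens_generated = 0
while True:
    hit_end_think_early = False
    for response in generator(
        prompt=prompt, max_tokens=4096, prompt_cache=prompt_cache
    ):
        if (
            response.token == end_think_token_id
            and tokens_generated < min_thinking_tokens
        ):
            hit_end_think_early = True
            break
        print(response.text, flush=True, end="")
        tokens_generated += 1
    if not hit_end_think_early:
        break
    # The </think> token is already in the cache; drop it (as reason.py does for
    # "Wait") and continue the thinking phase from a forced "Wait".
    trim_prompt_cache(prompt_cache, 1)
    print("Wait", flush=True, end="")
    prompt = mx.array([wait_token_id], mx.uint32)
print()

Once tokens_generated passes the budget, </think> is allowed through and the model finishes its answer within the same stream.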
