Install MLX LM:
pip install mlx-lm
And run:
python reason.py
The default model is mlx-community/DeepSeek-R1-Distill-Qwen-7B-4bit. You can specify the model with --model.
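For example, to point the script at a different model from the Hugging Face Hub (the 1.5B repo below is just an illustration, assuming an MLX-converted version exists under mlx-community):

python reason.py --model mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-4bit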
To see all the options:
python reason.py --help
# Copyright © 2023-2024 Apple Inc.

import argparse
import json
from functools import partial

import mlx.core as mx
from mlx_lm.models.cache import make_prompt_cache, trim_prompt_cache
from mlx_lm.sample_utils import make_sampler
from mlx_lm.utils import load, stream_generate

DEFAULT_TEMP = 0.0
DEFAULT_TOP_P = 1.0
DEFAULT_SEED = 0
DEFAULT_MAX_TOKENS = 4096
DEFAULT_MODEL = "mlx-community/DeepSeek-R1-Distill-Qwen-7B-4bit"


def setup_arg_parser():
    """Set up and return the argument parser."""
    parser = argparse.ArgumentParser(description="Chat with an LLM")
    parser.add_argument(
        "--model",
        type=str,
        help="The path to the local model directory or Hugging Face repo.",
        default=DEFAULT_MODEL,
    )
    parser.add_argument(
        "--adapter-path",
        type=str,
        help="Optional path for the trained adapter weights and config.",
    )
    parser.add_argument(
        "--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
    )
    parser.add_argument(
        "--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
    )
    parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
    parser.add_argument(
        "--max-tokens",
        "-m",
        type=int,
        default=DEFAULT_MAX_TOKENS,
        help="Maximum number of tokens to generate",
    )
    return parser


def main():
    parser = setup_arg_parser()
    args = parser.parse_args()
    mx.random.seed(args.seed)

    model, tokenizer = load(
        args.model,
        adapter_path=args.adapter_path,
        tokenizer_config={"trust_remote_code": True},
    )

    # "Wait" nudges the model to keep reasoning; "</think>" closes the
    # reasoning section and starts the final answer.
    wait_token = "Wait"
    wait_token_id = tokenizer.convert_tokens_to_ids(wait_token)
    end_think_token = "</think>"
    end_think_token_id = tokenizer.convert_tokens_to_ids(end_think_token)

    # Pre-encoded prompts used to either extend or close the thinking section
    think_more_prompt = mx.array([wait_token_id], mx.uint32)
    end_think_prompt = mx.array(
        tokenizer.encode(end_think_token + "\n", add_special_tokens=False), mx.uint32
    )

    generator = partial(
        stream_generate,
        model=model,
        tokenizer=tokenizer,
        sampler=make_sampler(args.temp, args.top_p),
    )

    print(f"[INFO] Starting reasoning session with {args.model}. To exit, enter 'q'.")
    while True:
        prompt_cache = make_prompt_cache(model)
        query = input(">> ")
        if query == "q":
            break
        messages = [{"role": "user", "content": query}]
        prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
        while True:
            max_tokens = args.max_tokens
            end_think_idx = None

            # Generate the thinking section, stopping at the first "Wait"
            # token and recording where "</think>" is produced
            for response in generator(
                prompt=prompt,
                max_tokens=max_tokens,
                prompt_cache=prompt_cache,
            ):
                if response.token == wait_token_id:
                    break
                elif response.token == end_think_token_id:
                    end_think_idx = prompt_cache[0].offset
                print(response.text, flush=True, end="")
            max_tokens -= response.generation_tokens

            # If we got a wait token insert </think> and generate the response
            if end_think_idx is None:
                print(end_think_token, flush=True)
                end_think_idx = prompt_cache[0].offset
                prompt = end_think_prompt
                # Trim the wait token from the cache
                trim_prompt_cache(prompt_cache, 1)

            # Generate answer
            for response in generator(
                prompt=prompt,
                max_tokens=max_tokens,
                prompt_cache=prompt_cache,
            ):
                print(response.text, flush=True, end="")
            max_tokens -= response.generation_tokens

            think_more = input(
                "\n\n\033[31mWould you like me to think more? (y/n):\033[0m "
            )
            if think_more == "y":
                # Trim the prompt cache to just before the end of think token
                print("<think>")
                print(wait_token, flush=True, end="")
                num_to_trim = prompt_cache[0].offset - end_think_idx + 1
                max_tokens += num_to_trim
                trim_prompt_cache(prompt_cache, num_to_trim)
                prompt = think_more_prompt
            else:
                break
        print()


if __name__ == "__main__":
    main()
The current method substitutes </think> for the first Wait it encounters: the thinking loop breaks as soon as it samples a Wait token, then the script trims that token from the cache, prints </think>, and forces the model to produce its answer.
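The script relies on both markers mapping to single ids in the tokenizer, which is what makes this one-token swap in the cache possible. A quick standalone way to check that (a minimal sketch; it loads the script's default model only to inspect its tokenizer):

from mlx_lm.utils import load

# Load the default model from reason.py just to get its tokenizer
_, tokenizer = load("mlx-community/DeepSeek-R1-Distill-Qwen-7B-4bit")

# Both markers should come back as single token ids
for marker in ("Wait", "</think>"):
    print(marker, "->", tokenizer.convert_tokens_to_ids(marker))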
@awni is it possible to modify the script so that it always substitutes </think> with Wait if the model hasn't generated at least max_tokens tokens yet? The current method waits for the model to finish generating everything and re-prompts it based on user input, but I'd like to substitute the token as it is being generated. Any idea how that's possible with MLX?