LlamaCPP Chat Inference in C
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <stdint.h>
#include <stdbool.h>
#include <math.h>   // For INFINITY

#include "llama.h"

#define MODEL_PATH            "llama-3.2-1b-instruct-q8_0.gguf"
#define MAX_GENERATION_TOKENS 2000
#define CHUNK_SIZE            (8 * 1024 * 1024)  // 8MB
// Function to read the system prompt from system.txt
char* read_system_prompt(const char* filename) {
    FILE* fp = fopen(filename, "r");
    if (!fp) {
        fprintf(stderr, "Error: could not open system prompt file '%s'\n", filename);
        return NULL;
    }
    fseek(fp, 0L, SEEK_END);
    size_t size = ftell(fp);
    if (size == 0) {
        fprintf(stderr, "Error: system prompt file '%s' is empty\n", filename);
        fclose(fp);
        return NULL;
    }
    fseek(fp, 0L, SEEK_SET);
    char* buffer = malloc(size + 1);
    if (!buffer) {
        fprintf(stderr, "Error: could not allocate memory for system prompt\n");
        fclose(fp);
        return NULL;
    }
    fread(buffer, 1, size, fp);
    buffer[size] = '\0';
    fclose(fp);
    return buffer;
}
// Conversation buffer structure
typedef struct {
    char   *data;
    size_t  size;
    size_t  capacity;
} conversation_buffer_t;

// Initialize conversation buffer
void init_conversation_buffer(conversation_buffer_t *buffer) {
    buffer->data = malloc(CHUNK_SIZE);
    if (buffer->data == NULL) {
        fprintf(stderr, "Error allocating memory for conversation buffer.\n");
        exit(EXIT_FAILURE);
    }
    buffer->size = 0;
    buffer->capacity = CHUNK_SIZE;
}
// Ensure the buffer has enough capacity, expanding in 8MB chunks until the
// appended text fits (a single append may be larger than one chunk)
void ensure_capacity(conversation_buffer_t *buffer, size_t additional_size) {
    if (buffer->size + additional_size < buffer->capacity) {
        return;
    }
    size_t new_capacity = buffer->capacity;
    while (buffer->size + additional_size >= new_capacity) {
        new_capacity += CHUNK_SIZE;
    }
    char *new_data = realloc(buffer->data, new_capacity);
    if (new_data == NULL) {
        fprintf(stderr, "Error reallocating memory for conversation buffer.\n");
        free(buffer->data);
        exit(EXIT_FAILURE);
    }
    buffer->data = new_data;
    buffer->capacity = new_capacity;
}
// Append text to the conversation buffer
void append_to_buffer(conversation_buffer_t *buffer, const char *text) {
    size_t text_len = strlen(text);
    ensure_capacity(buffer, text_len);
    memcpy(buffer->data + buffer->size, text, text_len);
    buffer->size += text_len;
}

// Free the conversation buffer
void free_conversation_buffer(conversation_buffer_t *buffer) {
    free(buffer->data);
    buffer->data = NULL;
    buffer->size = 0;
    buffer->capacity = 0;
}
int run_chat_mode(struct llama_model *model, struct llama_context_params ctx_params) {
    fprintf(stderr, "Creating new context with model...\n");
    struct llama_context *ctx = llama_new_context_with_model(model, ctx_params);
    if (ctx == NULL) {
        fprintf(stderr, "Failed to create context.\n");
        return EXIT_FAILURE;
    }

    fprintf(stderr, "Reading system prompt...\n");
    // Read system prompt from system.txt
    char* system_prompt = read_system_prompt("system.txt");
    if (!system_prompt) {
        llama_free(ctx);
        return EXIT_FAILURE;
    }

    fprintf(stderr, "Initializing conversation buffer...\n");
    // Initialize conversation buffer
    conversation_buffer_t convo_buffer;
    init_conversation_buffer(&convo_buffer);

    // Append system prompt to conversation buffer
    append_to_buffer(&convo_buffer, system_prompt);
    append_to_buffer(&convo_buffer, "\n"); // Separator

    fprintf(stderr, "Initializing sampler chain...\n");
    // Initialize sampler chain
    struct llama_sampler_chain_params sampler_params = llama_sampler_chain_default_params();
    struct llama_sampler *sampler = llama_sampler_chain_init(sampler_params);
    if (sampler == NULL) {
        fprintf(stderr, "Failed to initialize sampler chain.\n");
        llama_free(ctx);
        free(system_prompt);
        free_conversation_buffer(&convo_buffer);
        return EXIT_FAILURE;
    }

    // Set up sampling parameters
    int   top_k       = 40;
    float top_p       = 0.9f;  // Adjusted
    float temperature = 0.7f;  // Adjusted

    fprintf(stderr, "Adding repetition penalty sampler...\n");
    // Add repetition penalty sampler
    struct llama_sampler *penalty_sampler = llama_sampler_init_penalties(
        llama_n_vocab(model),
        llama_token_eos(model),
        llama_token_nl(model),
        64,      // penalty_last_n
        1.2f,    // penalty_repeat
        0.0f,    // penalty_freq
        0.0f,    // penalty_present
        false,   // penalize_nl
        false    // ignore_eos
    );
    if (penalty_sampler == NULL) {
        fprintf(stderr, "Failed to initialize penalty sampler.\n");
        // Handle error
    }
    llama_sampler_chain_add(sampler, penalty_sampler);
    // Add logit bias to prevent "User:" and "Assistant:"
    fprintf(stderr, "Adding logit bias sampler...\n");
    const char* forbidden_strings[] = {"User:", "Assistant:"};
    int n_forbidden = sizeof(forbidden_strings) / sizeof(forbidden_strings[0]);

    // Assuming max 10 tokens per string
    const int max_tokens_per_string = 10;
    int total_tokens = n_forbidden * max_tokens_per_string;
    llama_logit_bias *logit_biases = malloc(sizeof(llama_logit_bias) * total_tokens);

    int idx = 0;
    for (int i = 0; i < n_forbidden; ++i) {
        const char* str = forbidden_strings[i];
        llama_token tokens[max_tokens_per_string];
        // add_special = false: we only want the literal tokens of the string, not a leading BOS
        int n_tokens = llama_tokenize(model, str, strlen(str), tokens, max_tokens_per_string, false, false);
        if (n_tokens < 0) {
            fprintf(stderr, "Failed to tokenize '%s'.\n", str);
            llama_sampler_free(sampler);
            llama_free(ctx);
            free(system_prompt);
            free_conversation_buffer(&convo_buffer);
            free(logit_biases);
            return EXIT_FAILURE;
        }
        for (int j = 0; j < n_tokens; ++j) {
            logit_biases[idx].token = tokens[j];
            logit_biases[idx].bias  = -INFINITY;  // never sample these tokens
            idx++;
        }
    }
    // Update total_tokens to the actual number of biased tokens
    total_tokens = idx;

    // Add logit bias sampler
    struct llama_sampler *logit_bias_sampler = llama_sampler_init_logit_bias(
        llama_n_vocab(model),
        total_tokens,
        logit_biases
    );
    if (logit_bias_sampler == NULL) {
        fprintf(stderr, "Failed to initialize logit bias sampler.\n");
        // Handle error
    }
    llama_sampler_chain_add(sampler, logit_bias_sampler);
    free(logit_biases);
fprintf(stderr, "Adding temperature, top_k, and top_p samplers...\n"); | |
// Add temperature sampler | |
struct llama_sampler *temp_sampler = llama_sampler_init_temp(temperature); | |
llama_sampler_chain_add(sampler, temp_sampler); | |
// Add top_k sampler | |
struct llama_sampler *top_k_sampler = llama_sampler_init_top_k(top_k); | |
llama_sampler_chain_add(sampler, top_k_sampler); | |
// Add top_p sampler | |
struct llama_sampler *top_p_sampler = llama_sampler_init_top_p(top_p, 1); // min_keep = 1 | |
llama_sampler_chain_add(sampler, top_p_sampler); | |
// Add final sampler to select the token | |
struct llama_sampler *dist_sampler = llama_sampler_init_dist(LLAMA_DEFAULT_SEED); | |
llama_sampler_chain_add(sampler, dist_sampler); | |
    // Main chat loop
    bool is_running = true;
    fprintf(stderr, "Entering chat loop...\n"); // Debugging statement
    while (is_running) {
        printf("\n> ");  // Prompt the user (to stdout)
        fflush(stdout);  // Ensure prompt is displayed immediately

        char user_input[4096];
        if (!fgets(user_input, sizeof(user_input), stdin)) {
            // EOF or error
            fprintf(stderr, "No input received. Exiting chat loop.\n"); // Debugging statement
            break;
        }

        // Remove trailing newline
        size_t len = strlen(user_input);
        if (len > 0 && user_input[len - 1] == '\n') {
            user_input[len - 1] = '\0';
        }

        // Check for exit command
        if (strcmp(user_input, "exit") == 0) {
            break;
        }

        // Append user input to conversation buffer
        append_to_buffer(&convo_buffer, "User: ");
        append_to_buffer(&convo_buffer, user_input);
        append_to_buffer(&convo_buffer, "\nAssistant: ");

        // Build the full prompt from conversation buffer
        char *full_prompt = malloc(convo_buffer.size + 1);
        if (!full_prompt) {
            fprintf(stderr, "Error allocating memory for full prompt.\n");
            break;
        }
        memcpy(full_prompt, convo_buffer.data, convo_buffer.size);
        full_prompt[convo_buffer.size] = '\0';

        // Tokenize prompt
        int32_t n_tokens_max = ctx_params.n_ctx;
        llama_token *tokens = malloc(sizeof(llama_token) * n_tokens_max);
        if (!tokens) {
            fprintf(stderr, "Error allocating memory for tokens.\n");
            free(full_prompt);
            break;
        }
        int32_t n_tokens = llama_tokenize(
            model,
            full_prompt,
            strlen(full_prompt),
            tokens,
            n_tokens_max,
            true,
            false
        );
        if (n_tokens < 0) {
            fprintf(stderr, "Tokenization failed. Required tokens: %d\n", -n_tokens);
            llama_sampler_free(sampler);
            llama_free(ctx);
            free(system_prompt);
            free(full_prompt);
            free(tokens);
            free_conversation_buffer(&convo_buffer);
            return EXIT_FAILURE;
        }

        // Ensure the number of tokens doesn't exceed the context size
        if (n_tokens > ctx_params.n_ctx) {
            fprintf(stderr, "Context window exceeded. Trimming oldest messages.\n");
            // Trim the oldest messages
            int tokens_to_trim = n_tokens - ctx_params.n_ctx;
            memmove(tokens, tokens + tokens_to_trim, (n_tokens - tokens_to_trim) * sizeof(llama_token));
            n_tokens -= tokens_to_trim;
            // Adjust the conversation buffer accordingly (this is approximate)
            size_t bytes_to_trim = (size_t)((float)tokens_to_trim / n_tokens * convo_buffer.size);
            memmove(convo_buffer.data, convo_buffer.data + bytes_to_trim, convo_buffer.size - bytes_to_trim);
            convo_buffer.size -= bytes_to_trim;
        }
        // Initialize batch
        struct llama_batch batch;
        memset(&batch, 0, sizeof(struct llama_batch));
        batch.n_tokens   = n_tokens;
        batch.token      = tokens;
        batch.embd       = NULL;
        batch.all_pos_0  = 0;
        batch.all_pos_1  = 1;
        batch.all_seq_id = 0;

        // Decode tokens to get initial logits
        int32_t decode_status = llama_decode(ctx, batch);
        if (decode_status < 0) {
            fprintf(stderr, "Decoding failed with status: %d\n", decode_status);
            llama_sampler_free(sampler);
            llama_free(ctx);
            free(system_prompt);
            free(full_prompt);
            free(tokens);
            free_conversation_buffer(&convo_buffer);
            return EXIT_FAILURE;
        }

        int32_t n_past = n_tokens;
        int32_t max_generation_tokens = MAX_GENERATION_TOKENS;

        // Generate tokens
        for (int i = 0; i < max_generation_tokens; ++i) {
            // Retrieve logits
            float *logits = llama_get_logits(ctx);
            if (logits == NULL) {
                fprintf(stderr, "Failed to retrieve logits.\n");
                break;
            }

            // Sample next token
            llama_token sampled_token = llama_sampler_sample(sampler, ctx, -1); // -1 for the last token
            if (sampled_token < 0) {
                fprintf(stderr, "Sampling failed.\n");
                break;
            }

            // Check for EOS
            if (sampled_token == llama_token_eos(model)) {
                break;
            }
            // Convert token to text
            char token_text[256];
            int32_t token_length = llama_token_to_piece(
                model,
                sampled_token,
                token_text,
                sizeof(token_text) - 1,  // leave room for the terminating NUL below
                0,
                false
            );
            if (token_length < 0) {
                fprintf(stderr, "Detokenization failed for token ID: %d\n", sampled_token);
                break;
            }
            token_text[token_length] = '\0';
            // Print token (this is the chat output)
            printf("%s", token_text);
            fflush(stdout);

            // Append assistant's reply to conversation buffer
            append_to_buffer(&convo_buffer, token_text);

            // Stop generation if assistant ends its reply (detect newline)
            if (strchr(token_text, '\n') != NULL) {
                break;
            }

            // Prepare batch for the sampled token
            struct llama_batch gen_batch;
            memset(&gen_batch, 0, sizeof(struct llama_batch));
            gen_batch.n_tokens = 1;
            llama_token token_array[1] = { sampled_token };
            gen_batch.token      = token_array;
            gen_batch.embd       = NULL;
            gen_batch.all_pos_0  = n_past;
            gen_batch.all_pos_1  = 1;
            gen_batch.all_seq_id = 0;

            // Decode the new token
            decode_status = llama_decode(ctx, gen_batch);
            if (decode_status < 0) {
                fprintf(stderr, "Decoding failed with status: %d\n", decode_status);
                break;
            }

            // Accept the sampled token in the sampler
            llama_sampler_accept(sampler, sampled_token);

            // Update n_past
            n_past += 1;
        }

        // Add a newline to separate exchanges
        append_to_buffer(&convo_buffer, "\n");

        free(full_prompt);
        free(tokens);
    }
    // Clean up
    llama_sampler_free(sampler);
    llama_free(ctx);
    free(system_prompt);
    free_conversation_buffer(&convo_buffer);

    return EXIT_SUCCESS;
}
int main(int argc, char **argv) {
    fprintf(stderr, "Loading...\n");

    // Initialize llama backend
    llama_backend_init();

    // Load the model
    struct llama_model_params model_params = llama_model_default_params();
    fprintf(stderr, "Loading model from '%s'...\n", MODEL_PATH);
    struct llama_model *model = llama_load_model_from_file(MODEL_PATH, model_params);
    if (model == NULL) {
        fprintf(stderr, "Failed to load model from %s\n", MODEL_PATH);
        return EXIT_FAILURE;
    }

    // Initialize context parameters
    struct llama_context_params ctx_params = llama_context_default_params();
    ctx_params.n_ctx      = 2048;  // Adjusted to a reasonable context size
    ctx_params.n_batch    = 512;
    ctx_params.n_threads  = 8;     // Adjust based on your system
    ctx_params.embeddings = false;

    int ret = run_chat_mode(model, ctx_params);
    if (ret != EXIT_SUCCESS) {
        fprintf(stderr, "run_chat_mode failed with status: %d\n", ret);
    }

    // Clean up
    llama_free_model(model);
    llama_backend_free();

    return ret;
}
// This code works fine, but after a few back-and-forth interactions the model starts to
// produce nonsense / partial tokens and doesn't remember what we're talking about...
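// A likely (but not verified) explanation for the drift: each turn re-decodes the entire
// conversation starting at position 0, yet the KV cache left over from previous turns is
// never cleared, so stale entries pile up at the same positions until the cache fills and
// the model loses the thread. Below is a minimal sketch, assuming llama_kv_cache_clear()
// and llama_sampler_reset() from llama.h are available in this build; reset_turn_state()
// is a hypothetical helper, not part of the llama.cpp API.
void reset_turn_state(struct llama_context *ctx, struct llama_sampler *sampler) {
    // Drop all cached keys/values so the full prompt can be re-decoded from position 0
    // without colliding with cache entries from earlier turns.
    llama_kv_cache_clear(ctx);
    // Also clear the repetition-penalty history accumulated by the sampler chain.
    llama_sampler_reset(sampler);
}
// One way to use it would be to call reset_turn_state(ctx, sampler) at the top of each
// iteration of the chat loop, before llama_decode(ctx, batch). Re-decoding the whole
// history every turn is wasteful, though; keeping the cache and decoding only the new
// "User: ..." message at position n_past is the more common pattern. Two further hedged
// suggestions: the full-prompt batch must also stay within n_batch (512 here) for
// llama_decode to accept it, and prompting Llama 3.2 Instruct with its chat template
// (<|start_header_id|> ... <|eot_id|>) instead of plain "User:" / "Assistant:" lines
// may help it keep track of who said what.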