@ninjadynamics
Created October 3, 2024 14:57
LlamaCPP Chat Inference in C
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <stdint.h>
#include <stdbool.h>
#include <math.h> // For INFINITY
#include "llama.h"
#define MODEL_PATH "llama-3.2-1b-instruct-q8_0.gguf"
#define MAX_GENERATION_TOKENS 2000
#define CHUNK_SIZE (8 * 1024 * 1024) // 8MB
// Function to read the system prompt from system.txt
char* read_system_prompt(const char* filename) {
    FILE* fp = fopen(filename, "r");
    if (!fp) {
        fprintf(stderr, "Error: could not open system prompt file '%s'\n", filename);
        return NULL;
    }
    fseek(fp, 0L, SEEK_END);
    size_t size = ftell(fp);
    if (size == 0) {
        fprintf(stderr, "Error: system prompt file '%s' is empty\n", filename);
        fclose(fp);
        return NULL;
    }
    fseek(fp, 0L, SEEK_SET);
    char* buffer = malloc(size + 1);
    if (!buffer) {
        fprintf(stderr, "Error: could not allocate memory for system prompt\n");
        fclose(fp);
        return NULL;
    }
    size_t n_read = fread(buffer, 1, size, fp);
    buffer[n_read] = '\0'; // Terminate after the bytes actually read
    fclose(fp);
    return buffer;
}
// Conversation buffer structure
typedef struct {
    char *data;
    size_t size;
    size_t capacity;
} conversation_buffer_t;
// Initialize conversation buffer
void init_conversation_buffer(conversation_buffer_t *buffer) {
    buffer->data = malloc(CHUNK_SIZE);
    if (buffer->data == NULL) {
        fprintf(stderr, "Error allocating memory for conversation buffer.\n");
        exit(EXIT_FAILURE);
    }
    buffer->size = 0;
    buffer->capacity = CHUNK_SIZE;
}
// Ensure the buffer has enough capacity, expanding in 8MB chunks until it fits
void ensure_capacity(conversation_buffer_t *buffer, size_t additional_size) {
    while (buffer->size + additional_size >= buffer->capacity) {
        size_t new_capacity = buffer->capacity + CHUNK_SIZE;
        char *new_data = realloc(buffer->data, new_capacity);
        if (new_data == NULL) {
            fprintf(stderr, "Error reallocating memory for conversation buffer.\n");
            free(buffer->data);
            exit(EXIT_FAILURE);
        }
        buffer->data = new_data;
        buffer->capacity = new_capacity;
    }
}
// Append text to the conversation buffer
void append_to_buffer(conversation_buffer_t *buffer, const char *text) {
    size_t text_len = strlen(text);
    ensure_capacity(buffer, text_len);
    memcpy(buffer->data + buffer->size, text, text_len);
    buffer->size += text_len;
}
// Free the conversation buffer
void free_conversation_buffer(conversation_buffer_t *buffer) {
    free(buffer->data);
    buffer->data = NULL;
    buffer->size = 0;
    buffer->capacity = 0;
}
int run_chat_mode(struct llama_model *model, struct llama_context_params ctx_params) {
    fprintf(stderr, "Creating new context with model...\n");
    struct llama_context *ctx = llama_new_context_with_model(model, ctx_params);
    if (ctx == NULL) {
        fprintf(stderr, "Failed to create context.\n");
        return EXIT_FAILURE;
    }
    fprintf(stderr, "Reading system prompt...\n");
    // Read system prompt from system.txt
    char* system_prompt = read_system_prompt("system.txt");
    if (!system_prompt) {
        llama_free(ctx);
        return EXIT_FAILURE;
    }
    fprintf(stderr, "Initializing conversation buffer...\n");
    // Initialize conversation buffer
    conversation_buffer_t convo_buffer;
    init_conversation_buffer(&convo_buffer);
    // Append system prompt to conversation buffer
    append_to_buffer(&convo_buffer, system_prompt);
    append_to_buffer(&convo_buffer, "\n"); // Separator
    fprintf(stderr, "Initializing sampler chain...\n");
    // Initialize sampler chain
    struct llama_sampler_chain_params sampler_params = llama_sampler_chain_default_params();
    struct llama_sampler *sampler = llama_sampler_chain_init(sampler_params);
    if (sampler == NULL) {
        fprintf(stderr, "Failed to initialize sampler chain.\n");
        llama_free(ctx);
        free(system_prompt);
        free_conversation_buffer(&convo_buffer);
        return EXIT_FAILURE;
    }
    // Set up sampling parameters
    int top_k = 40;
    float top_p = 0.9f;       // Adjusted
    float temperature = 0.7f; // Adjusted
    fprintf(stderr, "Adding repetition penalty sampler...\n");
    // Add repetition penalty sampler
    struct llama_sampler *penalty_sampler = llama_sampler_init_penalties(
        llama_n_vocab(model),
        llama_token_eos(model),
        llama_token_nl(model),
        64,    // penalty_last_n
        1.2f,  // penalty_repeat
        0.0f,  // penalty_freq
        0.0f,  // penalty_present
        false, // penalize_nl
        false  // ignore_eos
    );
    if (penalty_sampler == NULL) {
        fprintf(stderr, "Failed to initialize penalty sampler.\n");
        // Handle error
    }
    llama_sampler_chain_add(sampler, penalty_sampler);
    // Add logit bias to prevent "User:" and "Assistant:"
    fprintf(stderr, "Adding logit bias sampler...\n");
    const char* forbidden_strings[] = {"User:", "Assistant:"};
    int n_forbidden = sizeof(forbidden_strings) / sizeof(forbidden_strings[0]);
    // Assuming max 10 tokens per string
    const int max_tokens_per_string = 10;
    int total_tokens = n_forbidden * max_tokens_per_string;
    llama_logit_bias *logit_biases = malloc(sizeof(llama_logit_bias) * total_tokens);
    int idx = 0;
    for (int i = 0; i < n_forbidden; ++i) {
        const char* str = forbidden_strings[i];
        llama_token tokens[max_tokens_per_string];
        int n_tokens = llama_tokenize(model, str, strlen(str), tokens, max_tokens_per_string, true, false);
        if (n_tokens < 0) {
            fprintf(stderr, "Failed to tokenize '%s'.\n", str);
            llama_sampler_free(sampler);
            llama_free(ctx);
            free(system_prompt);
            free_conversation_buffer(&convo_buffer);
            free(logit_biases);
            return EXIT_FAILURE;
        }
        for (int j = 0; j < n_tokens; ++j) {
            logit_biases[idx].token = tokens[j];
            logit_biases[idx].bias = -INFINITY;
            idx++;
        }
    }
    // Update total_tokens to actual number
    total_tokens = idx;
    // Add logit bias sampler
    struct llama_sampler *logit_bias_sampler = llama_sampler_init_logit_bias(
        llama_n_vocab(model),
        total_tokens,
        logit_biases
    );
    if (logit_bias_sampler == NULL) {
        fprintf(stderr, "Failed to initialize logit bias sampler.\n");
        // Handle error
    }
    llama_sampler_chain_add(sampler, logit_bias_sampler);
    free(logit_biases);
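    // Caveat (not verified against this model's tokenizer): biasing every sub-token of
    // "User:" and "Assistant:" to -INFINITY bans those tokens in *all* contexts, so if the
    // tokenizer splits the strings into pieces such as "User" and ":", the model can never
    // emit a colon (or the word "User") anywhere in its replies. Also, add_special=true in
    // the llama_tokenize() call above may prepend the BOS token, which then gets biased too.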
fprintf(stderr, "Adding temperature, top_k, and top_p samplers...\n");
// Add temperature sampler
struct llama_sampler *temp_sampler = llama_sampler_init_temp(temperature);
llama_sampler_chain_add(sampler, temp_sampler);
// Add top_k sampler
struct llama_sampler *top_k_sampler = llama_sampler_init_top_k(top_k);
llama_sampler_chain_add(sampler, top_k_sampler);
// Add top_p sampler
struct llama_sampler *top_p_sampler = llama_sampler_init_top_p(top_p, 1); // min_keep = 1
llama_sampler_chain_add(sampler, top_p_sampler);
// Add final sampler to select the token
struct llama_sampler *dist_sampler = llama_sampler_init_dist(LLAMA_DEFAULT_SEED);
llama_sampler_chain_add(sampler, dist_sampler);
    // Main chat loop
    bool is_running = true;
    fprintf(stderr, "Entering chat loop...\n"); // Debugging statement
    while (is_running) {
        printf("\n> "); // Prompt the user (to stdout)
        fflush(stdout); // Ensure prompt is displayed immediately
        char user_input[4096];
        if (!fgets(user_input, sizeof(user_input), stdin)) {
            // EOF or error
            fprintf(stderr, "No input received. Exiting chat loop.\n"); // Debugging statement
            break;
        }
        // Remove trailing newline
        size_t len = strlen(user_input);
        if (len > 0 && user_input[len - 1] == '\n') {
            user_input[len - 1] = '\0';
        }
        // Check for exit command
        if (strcmp(user_input, "exit") == 0) {
            break;
        }
        // Append user input to conversation buffer
        append_to_buffer(&convo_buffer, "User: ");
        append_to_buffer(&convo_buffer, user_input);
        append_to_buffer(&convo_buffer, "\nAssistant: ");
        // Build the full prompt from conversation buffer
        char *full_prompt = malloc(convo_buffer.size + 1);
        if (!full_prompt) {
            fprintf(stderr, "Error allocating memory for full prompt.\n");
            break;
        }
        memcpy(full_prompt, convo_buffer.data, convo_buffer.size);
        full_prompt[convo_buffer.size] = '\0';
        // Tokenize prompt
        int32_t n_tokens_max = ctx_params.n_ctx;
        llama_token *tokens = malloc(sizeof(llama_token) * n_tokens_max);
        if (!tokens) {
            fprintf(stderr, "Error allocating memory for tokens.\n");
            free(full_prompt);
            break;
        }
        int32_t n_tokens = llama_tokenize(
            model,
            full_prompt,
            strlen(full_prompt),
            tokens,
            n_tokens_max,
            true,
            false
        );
        if (n_tokens < 0) {
            fprintf(stderr, "Tokenization failed. Required tokens: %d\n", -n_tokens);
            llama_sampler_free(sampler);
            llama_free(ctx);
            free(system_prompt);
            free(full_prompt);
            free(tokens);
            free_conversation_buffer(&convo_buffer);
            return EXIT_FAILURE;
        }
        // Ensure the number of tokens doesn't exceed the context size
        if (n_tokens > ctx_params.n_ctx) {
            fprintf(stderr, "Context window exceeded. Trimming oldest messages.\n");
            // Trim the oldest messages
            int tokens_to_trim = n_tokens - ctx_params.n_ctx;
            memmove(tokens, tokens + tokens_to_trim, (n_tokens - tokens_to_trim) * sizeof(llama_token));
            n_tokens -= tokens_to_trim;
            // Adjust the conversation buffer accordingly (this is approximate)
            size_t bytes_to_trim = (size_t)((float)tokens_to_trim / n_tokens * convo_buffer.size);
            memmove(convo_buffer.data, convo_buffer.data + bytes_to_trim, convo_buffer.size - bytes_to_trim);
            convo_buffer.size -= bytes_to_trim;
        }
        // Initialize batch
        struct llama_batch batch;
        memset(&batch, 0, sizeof(struct llama_batch));
        batch.n_tokens = n_tokens;
        batch.token = tokens;
        batch.embd = NULL;
        batch.all_pos_0 = 0;
        batch.all_pos_1 = 1;
        batch.all_seq_id = 0;
        // Decode tokens to get initial logits
        int32_t decode_status = llama_decode(ctx, batch);
        if (decode_status < 0) {
            fprintf(stderr, "Decoding failed with status: %d\n", decode_status);
            llama_sampler_free(sampler);
            llama_free(ctx);
            free(system_prompt);
            free(full_prompt);
            free(tokens);
            free_conversation_buffer(&convo_buffer);
            return EXIT_FAILURE;
        }
        int32_t n_past = n_tokens;
        int32_t max_generation_tokens = MAX_GENERATION_TOKENS;
        // Generate tokens
        for (int i = 0; i < max_generation_tokens; ++i) {
            // Retrieve logits
            float *logits = llama_get_logits(ctx);
            if (logits == NULL) {
                fprintf(stderr, "Failed to retrieve logits.\n");
                break;
            }
            // Sample next token
            llama_token sampled_token = llama_sampler_sample(sampler, ctx, -1); // -1 for the last token
            if (sampled_token < 0) {
                fprintf(stderr, "Sampling failed.\n");
                break;
            }
            // Check for EOS
            if (sampled_token == llama_token_eos(model)) {
                break;
            }
            // Convert token to text
            char token_text[256];
            int32_t token_length = llama_token_to_piece(
                model,
                sampled_token,
                token_text,
                sizeof(token_text),
                0,
                false
            );
            if (token_length < 0) {
                fprintf(stderr, "Detokenization failed for token ID: %d\n", sampled_token);
                break;
            }
            token_text[token_length] = '\0';
            // Print token (this is the chat output)
            printf("%s", token_text);
            fflush(stdout);
            // Append assistant's reply to conversation buffer
            append_to_buffer(&convo_buffer, token_text);
            // Stop generation if assistant ends its reply (detect newline)
            if (strchr(token_text, '\n') != NULL) {
                break;
            }
            // Prepare batch for the sampled token
            struct llama_batch gen_batch;
            memset(&gen_batch, 0, sizeof(struct llama_batch));
            gen_batch.n_tokens = 1;
            llama_token token_array[1] = { sampled_token };
            gen_batch.token = token_array;
            gen_batch.embd = NULL;
            gen_batch.all_pos_0 = n_past;
            gen_batch.all_pos_1 = 1;
            gen_batch.all_seq_id = 0;
            // Decode the new token
            decode_status = llama_decode(ctx, gen_batch);
            if (decode_status < 0) {
                fprintf(stderr, "Decoding failed with status: %d\n", decode_status);
                break;
            }
            // Accept the sampled token in the sampler
            llama_sampler_accept(sampler, sampled_token);
            // Update n_past
            n_past += 1;
        }
        // Add a newline to separate exchanges
        append_to_buffer(&convo_buffer, "\n");
        free(full_prompt);
        free(tokens);
    }
    // Clean up
    llama_sampler_free(sampler);
    llama_free(ctx);
    free(system_prompt);
    free_conversation_buffer(&convo_buffer);
    return EXIT_SUCCESS;
}
int main(int argc, char **argv) {
    fprintf(stderr, "Loading...\n");
    // Initialize llama backend
    llama_backend_init();
    // Load the model
    struct llama_model_params model_params = llama_model_default_params();
    fprintf(stderr, "Loading model from '%s'...\n", MODEL_PATH);
    struct llama_model *model = llama_load_model_from_file(MODEL_PATH, model_params);
    if (model == NULL) {
        fprintf(stderr, "Failed to load model from %s\n", MODEL_PATH);
        return EXIT_FAILURE;
    }
    // Initialize context parameters
    struct llama_context_params ctx_params = llama_context_default_params();
    ctx_params.n_ctx = 2048;  // Adjusted to a reasonable context size
    ctx_params.n_batch = 512;
    ctx_params.n_threads = 8; // Adjust based on your system
    ctx_params.embeddings = false;
    int ret = run_chat_mode(model, ctx_params);
    if (ret != EXIT_SUCCESS) {
        fprintf(stderr, "run_chat_mode failed with status: %d\n", ret);
    }
    // Clean up
    llama_free_model(model);
    llama_backend_free();
    return ret;
}
// This code works fine, but after some back-and-forth interactions the model starts to produce nonsense/partial tokens and doesn't remember what we're talking about...
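//
// Likely explanation (a hedged guess, not verified against this exact llama.cpp revision):
// each turn re-tokenizes the whole conversation and decodes it again starting at position 0,
// but the context's KV cache still holds the cells from every previous turn. Duplicate cells
// pile up at overlapping positions until the 2048-cell cache is full; at that point
// llama_decode starts returning 1 ("could not find a KV slot"), a value the
// `decode_status < 0` checks never catch, and sampling then runs on stale logits, which looks
// like gibberish and amnesia. The approximate byte-based trim of convo_buffer can also cut
// the history mid-token, so the text the model sees drifts away from what was actually said.
//
// Minimal sketch of one possible fix, assuming llama_kv_cache_clear() is available in this
// llama.cpp build: wipe the cache before re-decoding the full prompt each turn, and treat any
// non-zero decode status as a failure.
//
//     llama_kv_cache_clear(ctx);                 // drop cached cells from earlier turns
//     int32_t decode_status = llama_decode(ctx, batch);
//     if (decode_status != 0) {                  // 1 = no KV slot, < 0 = hard error
//         fprintf(stderr, "Decoding failed with status: %d\n", decode_status);
//         break;
//     }
//
// Re-decoding the whole prompt every turn is simple but wasteful; a more efficient variant
// keeps the cache and decodes only the newly appended user tokens at position n_past.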