Simple Llama3.java integration demo with LangChain4j
///usr/bin/env jbang "$0" "$@" ; exit $?
//JAVA 21+
//PREVIEW
//COMPILE_OPTIONS --add-modules=jdk.incubator.vector
//RUNTIME_OPTIONS --add-modules=jdk.incubator.vector
//DEPS dev.langchain4j:langchain4j:0.35.0
//SOURCES https://github.com/mukel/llama3.java/blob/main/Llama3.java
// Simple integration demo with LangChain4j, run as follows:
// jbang LangChain4jDemo.java /path/to/model.gguf
// For additional details and instructions on where/how to download models
// visit https://github.com/mukel/llama3.java
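// As an illustration only (the GGUF file name below is hypothetical; use whichever
// model file you downloaded following the instructions in the repository):
//   jbang LangChain4jDemo.java Meta-Llama-3-8B-Instruct-Q4_0.gguf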
package com.llama4j;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.UserMessage;
import dev.langchain4j.service.SystemMessage;
import dev.langchain4j.service.TokenStream;
import java.io.IOException;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.Scanner;
public class LangChain4jDemo {

    interface Assistant {
        @SystemMessage("You are a helpful assistant")
        String chat(String userMessage);
    }

    interface StreamingAssistant {
        @SystemMessage("You are a helpful assistant")
        TokenStream chat(String userMessage);
    }

    public static void main(String[] args) throws IOException {
        ChatLanguageModel model = new Llama3ChatLanguageModel(Path.of(args[0]));
        // StreamingChatLanguageModel model = new Llama3StreamingChatLanguageModel(Path.of(args[0]));
        Assistant ai = AiServices.create(Assistant.class, model);
        // StreamingAssistant ai = AiServices.create(StreamingAssistant.class, model);
        try (Scanner in = new Scanner(System.in)) {
            while (true) {
                System.out.print("> ");
                String input = in.nextLine();
                System.out.println(ai.chat(input));
                // TokenStream stream = ai.chat(input);
                // stream.onNext(System.out::print)
                //         //.onComplete(System.out::println)
                //         .onError(Throwable::printStackTrace)
                //         .start();
            }
        }
    }
}
abstract class Llama3Base {
    final Llama model;
    final ChatFormat chatFormat;
    final Sampler sampler;
    final int maxTokens;

    Llama3Base(Path ggufPath) throws IOException {
        this(ggufPath, 512, 0.1f, 0.95f);
    }

    Llama3Base(Path ggufPath, int maxTokens, float temperature, float topp) throws IOException {
        this.model = ModelLoader.loadModel(ggufPath, maxTokens, true);
        this.chatFormat = new ChatFormat(model.tokenizer());
        // A negative maxTokens means "use the model's full context length".
        this.maxTokens = (maxTokens < 0) ? model.configuration().contextLength : maxTokens;
        assert 0 <= this.maxTokens && this.maxTokens <= model.configuration().contextLength;
        assert 0 <= temperature && temperature <= 1;
        assert 0 <= topp && topp <= 1;
        this.sampler = Llama3.selectSampler(model.configuration().vocabularySize, temperature, topp, 42);
    }

    static ChatFormat.Message toLlamaMessage(ChatMessage chatMessage) {
        return switch (chatMessage) {
            case dev.langchain4j.data.message.UserMessage userMessage -> {
                String name = userMessage.name();
                ChatFormat.Role role = (name == null || "user".equals(name)) ? ChatFormat.Role.USER : new ChatFormat.Role(name);
                yield new ChatFormat.Message(role, userMessage.singleText());
            }
            case dev.langchain4j.data.message.SystemMessage systemMessage ->
                    new ChatFormat.Message(ChatFormat.Role.SYSTEM, systemMessage.text());
            case AiMessage aiMessage -> new ChatFormat.Message(ChatFormat.Role.ASSISTANT, aiMessage.text());
            default -> throw new IllegalArgumentException("Cannot convert to Llama message");
        };
    }
}
class Llama3ChatLanguageModel extends Llama3Base implements ChatLanguageModel {

    Llama3ChatLanguageModel(Path ggufPath) throws IOException {
        super(ggufPath, 512, 0.1f, 0.95f);
    }

    @Override
    public Response<AiMessage> generate(List<ChatMessage> messages) {
        List<ChatFormat.Message> llamaMessages = messages.stream().map(Llama3Base::toLlamaMessage).toList();
        List<Integer> promptTokens = chatFormat.encodeDialogPrompt(true, llamaMessages);
        List<Integer> responseTokens = new ArrayList<>();
        Llama.generateTokens(model, model.createNewState(), 0, promptTokens, chatFormat.getStopTokens(), maxTokens, sampler, false, responseTokens::add);
        TokenUsage tokenUsage = new TokenUsage(promptTokens.size(), responseTokens.size());
        FinishReason finishReason = FinishReason.LENGTH;
        if (!responseTokens.isEmpty() && chatFormat.getStopTokens().contains(responseTokens.getLast())) {
            finishReason = FinishReason.STOP;
            responseTokens.removeLast(); // drop the stop token from the answer
        }
        String responseText = model.tokenizer().decode(responseTokens);
        return Response.from(AiMessage.from(responseText), tokenUsage, finishReason);
    }
}
class Llama3StreamingChatLanguageModel extends Llama3Base implements StreamingChatLanguageModel {

    Llama3StreamingChatLanguageModel(Path ggufPath) throws IOException {
        super(ggufPath, 512, 0.1f, 0.95f);
    }

    @Override
    public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMessage> handler) {
        List<ChatFormat.Message> llamaMessages = messages.stream().map(Llama3Base::toLlamaMessage).toList();
        List<Integer> promptTokens = chatFormat.encodeDialogPrompt(true, llamaMessages);
        List<Integer> responseTokens = new ArrayList<>();
        Llama.generateTokens(model, model.createNewState(), 0, promptTokens, chatFormat.getStopTokens(), maxTokens, sampler, false, token -> {
            responseTokens.add(token);
            // Stream each decoded token to the handler, skipping stop tokens.
            if (!chatFormat.getStopTokens().contains(token)) {
                handler.onNext(model.tokenizer().decode(List.of(token)));
            }
        });
        TokenUsage tokenUsage = new TokenUsage(promptTokens.size(), responseTokens.size());
        FinishReason finishReason = FinishReason.LENGTH;
        if (!responseTokens.isEmpty() && chatFormat.getStopTokens().contains(responseTokens.getLast())) {
            finishReason = FinishReason.STOP;
            responseTokens.removeLast(); // drop the stop token from the answer
        }
        String responseText = model.tokenizer().decode(responseTokens);
        handler.onComplete(Response.from(AiMessage.from(responseText), tokenUsage, finishReason));
    }
}