Simple Llama3.java integration demo with LangChain4j
///usr/bin/env jbang "$0" "$@" ; exit $?
//JAVA 21+
//PREVIEW
//COMPILE_OPTIONS --add-modules=jdk.incubator.vector
//RUNTIME_OPTIONS --add-modules=jdk.incubator.vector
//DEPS dev.langchain4j:langchain4j:0.35.0
//SOURCES https://github.com/mukel/llama3.java/blob/main/Llama3.java

// Simple integration demo with LangChain4j, run as follows:
//   jbang LangChain4jDemo.java /path/to/model.gguf
// For additional details and instructions on where/how to download models,
// visit https://github.com/mukel/llama3.java

package com.llama4j;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.UserMessage;
import dev.langchain4j.service.SystemMessage;
import dev.langchain4j.service.TokenStream;

import java.io.IOException;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.Scanner;

public class LangChain4jDemo {
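    // AiServices derives working implementations of these annotated interfaces at
    // runtime; the blocking variant returns the whole answer as a String, the
    // streaming variant hands back a TokenStream that emits tokens as they arrive.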
    interface Assistant {
        @SystemMessage("You are a helpful assistant")
        String chat(String userMessage);
    }

    interface StreamingAssistant {
        @SystemMessage("You are a helpful assistant")
        TokenStream chat(String userMessage);
    }
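    // Minimal REPL: read a line, send it through the model, print the full answer.
    // Swap in the commented lines to try the streaming variant instead.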
    public static void main(String[] args) throws IOException {
        ChatLanguageModel model = new Llama3ChatLanguageModel(Path.of(args[0]));
        // StreamingChatLanguageModel model = new Llama3StreamingChatLanguageModel(Path.of(args[0]));
        Assistant ai = AiServices.create(Assistant.class, model);
        // StreamingAssistant ai = AiServices.create(StreamingAssistant.class, model);
        try (Scanner in = new Scanner(System.in)) {
            while (true) {
                System.out.print("> ");
                String input = in.nextLine();
                System.out.println(ai.chat(input));
                // TokenStream stream = ai.chat(input);
                // stream.onNext(System.out::print)
                //         //.onComplete(System.out::println)
                //         .onError(Throwable::printStackTrace)
                //         .start();
            }
        }
    }
}
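// Shared base: loads the GGUF model once and holds the chat template, sampler
// and token budget used by both the blocking and the streaming adapters below.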
abstract class Llama3Base {
    final Llama model;
    final ChatFormat chatFormat;
    final Sampler sampler;
    final int maxTokens;

    Llama3Base(Path ggufPath) throws IOException {
        this(ggufPath, 512, 0.1f, 0.95f);
    }

    Llama3Base(Path ggufPath, int maxTokens, float temperature, float topp) throws IOException {
        this.model = ModelLoader.loadModel(ggufPath, maxTokens, true);
        this.chatFormat = new ChatFormat(model.tokenizer());
        // A negative maxTokens means "use the model's full context length".
        this.maxTokens = (maxTokens < 0) ? model.configuration().contextLength : maxTokens;
        assert 0 <= this.maxTokens && this.maxTokens <= model.configuration().contextLength;
        assert 0 <= temperature && temperature <= 1;
        assert 0 <= topp && topp <= 1;
        this.sampler = Llama3.selectSampler(model.configuration().vocabularySize, temperature, topp, 42);
    }
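    // Maps LangChain4j chat messages onto the roles of the Llama 3 chat template;
    // a non-default user name becomes a custom role.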
    static ChatFormat.Message toLlamaMessage(ChatMessage chatMessage) {
        return switch (chatMessage) {
            case dev.langchain4j.data.message.UserMessage userMessage -> {
                String name = userMessage.name();
                ChatFormat.Role role = (name == null || "user".equals(name)) ? ChatFormat.Role.USER : new ChatFormat.Role(name);
                yield new ChatFormat.Message(role, userMessage.singleText());
            }
            case dev.langchain4j.data.message.SystemMessage systemMessage ->
                    new ChatFormat.Message(ChatFormat.Role.SYSTEM, systemMessage.text());
            case AiMessage aiMessage -> new ChatFormat.Message(ChatFormat.Role.ASSISTANT, aiMessage.text());
            default -> throw new IllegalArgumentException("Cannot convert to Llama message");
        };
    }
}
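// Blocking adapter: generates the whole completion, then returns it in a single Response.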
class Llama3ChatLanguageModel extends Llama3Base implements ChatLanguageModel {

    Llama3ChatLanguageModel(Path ggufPath) throws IOException {
        super(ggufPath, 512, 0.1f, 0.95f);
    }

    @Override
    public Response<AiMessage> generate(List<ChatMessage> messages) {
        List<ChatFormat.Message> llamaMessages = messages.stream().map(Llama3Base::toLlamaMessage).toList();
        List<Integer> promptTokens = chatFormat.encodeDialogPrompt(true, llamaMessages);
        List<Integer> responseTokens = new ArrayList<>();
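        // Generate until a stop token or the token budget is hit; every sampled
        // token (including a possible trailing stop token) lands in responseTokens.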
        Llama.generateTokens(model, model.createNewState(), 0, promptTokens, chatFormat.getStopTokens(), maxTokens, sampler, false, responseTokens::add);
        TokenUsage tokenUsage = new TokenUsage(promptTokens.size(), responseTokens.size());
        FinishReason finishReason = FinishReason.LENGTH;
        if (!responseTokens.isEmpty() && chatFormat.getStopTokens().contains(responseTokens.getLast())) {
            finishReason = FinishReason.STOP;
            responseTokens.removeLast(); // drop the stop token from the answer
        }
        String responseText = model.tokenizer().decode(responseTokens);
        return Response.from(AiMessage.from(responseText), tokenUsage, finishReason);
    }
}
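// Streaming adapter: pushes each decoded token to the handler as it is sampled,
// then delivers the full response via onComplete.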
class Llama3StreamingChatLanguageModel extends Llama3Base implements StreamingChatLanguageModel {

    Llama3StreamingChatLanguageModel(Path ggufPath) throws IOException {
        super(ggufPath, 512, 0.1f, 0.95f);
    }

    @Override
    public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMessage> handler) {
        List<ChatFormat.Message> llamaMessages = messages.stream().map(Llama3Base::toLlamaMessage).toList();
        List<Integer> promptTokens = chatFormat.encodeDialogPrompt(true, llamaMessages);
        List<Integer> responseTokens = new ArrayList<>();
        Llama.generateTokens(model, model.createNewState(), 0, promptTokens, chatFormat.getStopTokens(), maxTokens, sampler, false, token -> {
            responseTokens.add(token);
            // Stop tokens terminate generation; don't surface them as output.
            if (!chatFormat.getStopTokens().contains(token)) {
                handler.onNext(model.tokenizer().decode(List.of(token)));
            }
        });
        TokenUsage tokenUsage = new TokenUsage(promptTokens.size(), responseTokens.size());
        FinishReason finishReason = FinishReason.LENGTH;
        if (!responseTokens.isEmpty() && chatFormat.getStopTokens().contains(responseTokens.getLast())) {
            finishReason = FinishReason.STOP;
            responseTokens.removeLast(); // drop the stop token from the answer
        }
        String responseText = model.tokenizer().decode(responseTokens);
        handler.onComplete(Response.from(AiMessage.from(responseText), tokenUsage, finishReason));
    }
}
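The Assistant above is stateless: each prompt reaches the model without the preceding turns. If you want the REPL to remember context, LangChain4j can wire a chat memory into the service. The following is a minimal, untested sketch of an alternative setup in main, assuming the AiServices builder and MessageWindowChatMemory from langchain4j 0.35.0 (it needs import dev.langchain4j.memory.chat.MessageWindowChatMemory; added to the imports):

// Sketch only: same REPL as main(), but the assistant now sees prior turns.
ChatLanguageModel model = new Llama3ChatLanguageModel(Path.of(args[0]));
Assistant ai = AiServices.builder(Assistant.class)
        .chatLanguageModel(model)
        .chatMemory(MessageWindowChatMemory.withMaxMessages(10)) // keep the last 10 messages
        .build();
// ...the Scanner loop from main() stays unchanged.

With a window memory in place, a follow-up like "and how about in Java?" keeps its context, at the cost of a longer prompt (and more prompt tokens) on every turn.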