Skip to content

Commit 9cf237a

Browse files
committed
add LlamaOutput, LlamaIterable, and LlamaIterator
1 parent ea3934d commit 9cf237a

8 files changed

Lines changed: 120 additions & 77 deletions

File tree

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package de.kherud.llama;
2+
3+
import org.jetbrains.annotations.NotNull;
4+
5+
/**
 * An iterable used by {@link LlamaModel#generate(InferenceParameters)} that specifically returns a
 * {@link LlamaIterator} instead of a plain {@link java.util.Iterator}, so that callers can cancel
 * the ongoing inference via {@link LlamaIterator#cancel()}.
 */
@FunctionalInterface
public interface LlamaIterable extends Iterable<LlamaOutput> {

    /**
     * Returns an iterator over the generated outputs. Never {@code null}.
     */
    @NotNull
    @Override
    LlamaIterator iterator();

}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package de.kherud.llama;
2+
3+
import java.lang.annotation.Native;
4+
import java.util.Iterator;
5+
import java.util.NoSuchElementException;
6+
7+
/**
 * This iterator is used by {@link LlamaModel#generate(InferenceParameters)}. In addition to implementing {@link Iterator},
 * it allows to cancel ongoing inference (see {@link #cancel()}).
 */
public final class LlamaIterator implements Iterator<LlamaOutput> {

    // Model that owns the native completion task; all native calls are delegated to it.
    private final LlamaModel model;
    // Identifier of the streaming task on the native side, obtained from requestCompletion.
    private final int taskId;

    // NOTE(review): annotated @Native and deliberately non-final — presumably also written from
    // native code when the task finishes; confirm against the C++ side before changing.
    @Native
    @SuppressWarnings("FieldMayBeFinal")
    private boolean hasNext = true;

    /**
     * Starts a streaming completion task for the given parameters.
     * Note: mutates {@code parameters} by forcing streaming mode on.
     */
    LlamaIterator(LlamaModel model, InferenceParameters parameters) {
        this.model = model;
        parameters.setStream(true);
        taskId = model.requestCompletion(parameters.toString());
    }

    @Override
    public boolean hasNext() {
        return hasNext;
    }

    /**
     * Returns the next chunk of generated output.
     *
     * @throws NoSuchElementException if generation already finished or was cancelled
     */
    @Override
    public LlamaOutput next() {
        if (!hasNext) {
            throw new NoSuchElementException();
        }
        LlamaOutput output = model.receiveCompletion(taskId);
        // The final output of a task carries stop = true; after it, iteration ends.
        hasNext = !output.stop;
        return output;
    }

    /**
     * Cancel the ongoing generation process.
     */
    public void cancel() {
        model.cancelCompletion(taskId);
        hasNext = false;
    }
}

src/main/java/de/kherud/llama/LlamaModel.java

Lines changed: 9 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,6 @@
22

33
import java.lang.annotation.Native;
44
import java.nio.charset.StandardCharsets;
5-
import java.util.Iterator;
6-
import java.util.Map;
7-
import java.util.NoSuchElementException;
8-
9-
import org.jetbrains.annotations.NotNull;
105

116
/**
127
* This class is a wrapper around the llama.cpp functionality.
@@ -54,7 +49,7 @@ public LlamaModel(ModelParameters parameters) {
5449
public String complete(InferenceParameters parameters) {
5550
parameters.setStream(false);
5651
int taskId = requestCompletion(parameters.toString());
57-
Output output = receiveCompletion(taskId);
52+
LlamaOutput output = receiveCompletion(taskId);
5853
return output.text;
5954
}
6055

@@ -64,8 +59,8 @@ public String complete(InferenceParameters parameters) {
6459
*
6560
* @return iterable LLM outputs
6661
*/
public LlamaIterable generate(InferenceParameters parameters) {
    // The lambda is a LlamaIterable (a @FunctionalInterface) whose iterator() starts a fresh
    // streaming task each time it is invoked. Note: constructing the LlamaIterator mutates
    // the given parameters by forcing streaming mode on.
    return () -> new LlamaIterator(this, parameters);
}
7065

7166
/**
@@ -98,79 +93,22 @@ public String decode(int[] tokens) {
9893
return new String(bytes, StandardCharsets.UTF_8);
9994
}
10095

101-
// /**
102-
// * Sets a callback for both Java and C++ log messages. Can be set to {@code null} to disable logging.
103-
// *
104-
// * @param callback a method to call for log messages
105-
// */
106-
// public static native void setLogger(@Nullable BiConsumer<LogLevel, String> callback);
107-
10896
@Override
public void close() {
    // Frees the natively allocated model state (see the native delete()); the instance
    // must not be used afterwards.
    delete();
}
112100

113101
// don't overload native methods since the C++ function names get nasty
114-
private native void loadModel(String parameters) throws LlamaException;
102+
native int requestCompletion(String params) throws LlamaException;
115103

116-
private native int requestCompletion(String params) throws LlamaException;
104+
native LlamaOutput receiveCompletion(int taskId) throws LlamaException;
117105

118-
private native Output receiveCompletion(int taskId) throws LlamaException;
106+
native void cancelCompletion(int taskId);
119107

120-
private native byte[] decodeBytes(int[] tokens);
108+
native byte[] decodeBytes(int[] tokens);
121109

122-
private native void delete();
110+
private native void loadModel(String parameters) throws LlamaException;
123111

124-
/**
125-
* A generated output of the LLM. Note that you have to configure {@link InferenceParameters#setNProbs(int)}
126-
* in order for probabilities to be returned.
127-
*/
128-
public static final class Output {
129-
130-
@NotNull
131-
public final String text;
132-
@NotNull
133-
public final Map<String, Float> probabilities;
134-
private final boolean stop;
135-
136-
private Output(byte[] generated, @NotNull Map<String, Float> probabilities, boolean stop) {
137-
this.text = new String(generated, StandardCharsets.UTF_8);
138-
this.probabilities = probabilities;
139-
this.stop = stop;
140-
}
141-
142-
@Override
143-
public String toString() {
144-
return text;
145-
}
146-
}
112+
private native void delete();
147113

148-
private final class LlamaIterator implements Iterator<Output> {
149-
150-
private final int taskId;
151-
152-
@Native
153-
@SuppressWarnings("FieldMayBeFinal")
154-
private boolean hasNext = true;
155-
156-
private LlamaIterator(InferenceParameters parameters) {
157-
parameters.setStream(true);
158-
taskId = requestCompletion(parameters.toString());
159-
}
160-
161-
@Override
162-
public boolean hasNext() {
163-
return hasNext;
164-
}
165-
166-
@Override
167-
public Output next() {
168-
if (!hasNext) {
169-
throw new NoSuchElementException();
170-
}
171-
Output output = receiveCompletion(taskId);
172-
hasNext = !output.stop;
173-
return output;
174-
}
175-
}
176114
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package de.kherud.llama;
2+
3+
import org.jetbrains.annotations.NotNull;
4+
5+
import java.nio.charset.StandardCharsets;
6+
import java.util.Map;
7+
8+
/**
9+
* An output of the LLM providing access to the generated text and the associated probabilities. You have to configure
10+
* {@link InferenceParameters#setNProbs(int)} in order for probabilities to be returned.
11+
*/
12+
public final class LlamaOutput {
13+
14+
/**
15+
* The last bit of generated text that is representable as text (i.e., cannot be individual utf-8 multibyte code
16+
* points).
17+
*/
18+
@NotNull
19+
public final String text;
20+
21+
/**
22+
* Note, that you have to configure {@link InferenceParameters#setNProbs(int)} in order for probabilities to be returned.
23+
*/
24+
@NotNull
25+
public final Map<String, Float> probabilities;
26+
27+
final boolean stop;
28+
29+
LlamaOutput(byte[] generated, @NotNull Map<String, Float> probabilities, boolean stop) {
30+
this.text = new String(generated, StandardCharsets.UTF_8);
31+
this.probabilities = probabilities;
32+
this.stop = stop;
33+
}
34+
35+
@Override
36+
public String toString() {
37+
return text;
38+
}
39+
}

src/test/java/de/kherud/llama/LlamaModelTest.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ public void testGenerateAnswer() {
4545
.setTokenIdBias(logitBias);
4646

4747
int generated = 0;
48-
for (LlamaModel.Output ignored : model.generate(params)) {
48+
for (LlamaOutput ignored : model.generate(params)) {
4949
generated++;
5050
}
5151
// todo: currently, after generating nPredict tokens, there is an additional empty output
@@ -66,7 +66,7 @@ public void testGenerateInfill() {
6666
.setSeed(42);
6767

6868
int generated = 0;
69-
for (LlamaModel.Output ignored : model.generate(params)) {
69+
for (LlamaOutput ignored : model.generate(params)) {
7070
generated++;
7171
}
7272
Assert.assertTrue(generated > 0 && generated <= nPredict + 1);
@@ -78,7 +78,7 @@ public void testGenerateGrammar() {
7878
.setGrammar("root ::= (\"a\" | \"b\")+")
7979
.setNPredict(nPredict);
8080
StringBuilder sb = new StringBuilder();
81-
for (LlamaModel.Output output : model.generate(params)) {
81+
for (LlamaOutput output : model.generate(params)) {
8282
sb.append(output);
8383
}
8484
String output = sb.toString();

src/test/java/examples/GrammarExample.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package examples;
22

3+
import de.kherud.llama.LlamaOutput;
34
import de.kherud.llama.ModelParameters;
45

56
import de.kherud.llama.InferenceParameters;
@@ -16,7 +17,7 @@ public static void main(String... args) {
1617
InferenceParameters inferParams = new InferenceParameters("")
1718
.setGrammar(grammar);
1819
try (LlamaModel model = new LlamaModel(modelParams)) {
19-
for (LlamaModel.Output output : model.generate(inferParams)) {
20+
for (LlamaOutput output : model.generate(inferParams)) {
2021
System.out.print(output);
2122
}
2223
}

src/test/java/examples/InfillExample.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import de.kherud.llama.InferenceParameters;
44
import de.kherud.llama.LlamaModel;
5+
import de.kherud.llama.LlamaOutput;
56
import de.kherud.llama.ModelParameters;
67

78
public class InfillExample {
@@ -18,7 +19,7 @@ public static void main(String... args) {
1819
InferenceParameters inferParams = new InferenceParameters("")
1920
.setInputPrefix(prefix)
2021
.setInputSuffix(suffix);
21-
for (LlamaModel.Output output : model.generate(inferParams)) {
22+
for (LlamaOutput output : model.generate(inferParams)) {
2223
System.out.print(output);
2324
}
2425
System.out.print(suffix);

src/test/java/examples/MainExample.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import de.kherud.llama.InferenceParameters;
99
import de.kherud.llama.LlamaModel;
10+
import de.kherud.llama.LlamaOutput;
1011
import de.kherud.llama.ModelParameters;
1112
import de.kherud.llama.args.MiroStat;
1213

@@ -39,7 +40,7 @@ public static void main(String... args) throws IOException {
3940
.setPenalizeNl(true)
4041
.setMiroStat(MiroStat.V2)
4142
.setStopStrings("User:");
42-
for (LlamaModel.Output output : model.generate(inferParams)) {
43+
for (LlamaOutput output : model.generate(inferParams)) {
4344
System.out.print(output);
4445
prompt += output;
4546
}

0 commit comments

Comments
 (0)