
 import java.lang.annotation.Native;
 import java.nio.charset.StandardCharsets;
-import java.util.Iterator;
-import java.util.Map;
-import java.util.NoSuchElementException;
-
-import org.jetbrains.annotations.NotNull;

 /**
  * This class is a wrapper around the llama.cpp functionality.
@@ -54,7 +49,7 @@ public LlamaModel(ModelParameters parameters) {
     public String complete(InferenceParameters parameters) {
         parameters.setStream(false);
         int taskId = requestCompletion(parameters.toString());
-        Output output = receiveCompletion(taskId);
+        LlamaOutput output = receiveCompletion(taskId);
         return output.text;
     }

@@ -64,8 +59,8 @@ public String complete(InferenceParameters parameters) {
      *
      * @return iterable LLM outputs
      */
-    public Iterable<Output> generate(InferenceParameters parameters) {
-        return () -> new LlamaIterator(parameters);
+    public LlamaIterable generate(InferenceParameters parameters) {
+        return () -> new LlamaIterator(this, parameters);
     }

     /**
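The two hunks above change the public return types: complete() still returns the generated text, while generate() now returns a LlamaIterable of LlamaOutput instead of Iterable<Output>. A minimal usage sketch, not part of the diff: runExample is a hypothetical helper, the ModelParameters/InferenceParameters setup is unchanged by this commit and simply passed in, LlamaModel is assumed to implement AutoCloseable (it overrides close() below), and LlamaOutput is assumed to keep a public text field like the removed Output class did.

// Hedged usage sketch of the changed API, not part of the diff.
static void runExample(ModelParameters modelParams, InferenceParameters inferParams) {
    try (LlamaModel model = new LlamaModel(modelParams)) {
        // Blocking completion: same call as before, still returns the generated text.
        String answer = model.complete(inferParams);
        System.out.println(answer);

        // Streaming: generate() now returns a LlamaIterable yielding LlamaOutput.
        for (LlamaOutput output : model.generate(inferParams)) {
            System.out.print(output.text); // assumes LlamaOutput.text is public, like Output.text was
        }
    }
}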
@@ -98,79 +93,22 @@ public String decode(int[] tokens) {
         return new String(bytes, StandardCharsets.UTF_8);
     }

-    // /**
-    //  * Sets a callback for both Java and C++ log messages. Can be set to {@code null} to disable logging.
-    //  *
-    //  * @param callback a method to call for log messages
-    //  */
-    // public static native void setLogger(@Nullable BiConsumer<LogLevel, String> callback);
-
     @Override
     public void close() {
         delete();
     }

     // don't overload native methods since the C++ function names get nasty
-    private native void loadModel(String parameters) throws LlamaException;
+    native int requestCompletion(String params) throws LlamaException;

-    private native int requestCompletion(String params) throws LlamaException;
+    native LlamaOutput receiveCompletion(int taskId) throws LlamaException;

-    private native Output receiveCompletion(int taskId) throws LlamaException;
+    native void cancelCompletion(int taskId);

-    private native byte[] decodeBytes(int[] tokens);
+    native byte[] decodeBytes(int[] tokens);

-    private native void delete();
+    private native void loadModel(String parameters) throws LlamaException;

-    /**
-     * A generated output of the LLM. Note that you have to configure {@link InferenceParameters#setNProbs(int)}
-     * in order for probabilities to be returned.
-     */
-    public static final class Output {
-
-        @NotNull
-        public final String text;
-        @NotNull
-        public final Map<String, Float> probabilities;
-        private final boolean stop;
-
-        private Output(byte[] generated, @NotNull Map<String, Float> probabilities, boolean stop) {
-            this.text = new String(generated, StandardCharsets.UTF_8);
-            this.probabilities = probabilities;
-            this.stop = stop;
-        }
-
-        @Override
-        public String toString() {
-            return text;
-        }
-    }
+    private native void delete();

-    private final class LlamaIterator implements Iterator<Output> {
-
-        private final int taskId;
-
-        @Native
-        @SuppressWarnings("FieldMayBeFinal")
-        private boolean hasNext = true;
-
-        private LlamaIterator(InferenceParameters parameters) {
-            parameters.setStream(true);
-            taskId = requestCompletion(parameters.toString());
-        }
-
-        @Override
-        public boolean hasNext() {
-            return hasNext;
-        }
-
-        @Override
-        public Output next() {
-            if (!hasNext) {
-                throw new NoSuchElementException();
-            }
-            Output output = receiveCompletion(taskId);
-            hasNext = !output.stop;
-            return output;
-        }
-    }
 }
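The removed Output and LlamaIterator types are superseded by top-level LlamaOutput, LlamaIterable and LlamaIterator classes added elsewhere in this commit; making requestCompletion, receiveCompletion, cancelCompletion and decodeBytes package-visible is what lets the iterator live outside LlamaModel. Below is a rough reconstruction of what such an iterator could look like, pieced together from the removed inner class and the new LlamaIterator(this, parameters) call above; the actual classes in the commit may differ.

// Sketch only; assumed to sit in the same package as LlamaModel so the
// package-visible native methods are reachable.
import java.lang.annotation.Native;
import java.util.Iterator;
import java.util.NoSuchElementException;

// A functional interface, so generate() can return "() -> new LlamaIterator(this, parameters)".
@FunctionalInterface
interface LlamaIterable extends Iterable<LlamaOutput> {
    @Override
    LlamaIterator iterator();
}

final class LlamaIterator implements Iterator<LlamaOutput> {

    private final LlamaModel model;
    private final int taskId;

    @Native // kept from the removed inner class
    @SuppressWarnings("FieldMayBeFinal")
    private boolean hasNext = true;

    LlamaIterator(LlamaModel model, InferenceParameters parameters) {
        this.model = model;
        parameters.setStream(true);
        taskId = model.requestCompletion(parameters.toString());
    }

    @Override
    public boolean hasNext() {
        return hasNext;
    }

    @Override
    public LlamaOutput next() {
        if (!hasNext) {
            throw new NoSuchElementException();
        }
        LlamaOutput output = model.receiveCompletion(taskId);
        hasNext = !output.stop; // assumes LlamaOutput exposes the stop flag within the package
        return output;
    }
}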