diff --git a/.gitmodules b/.gitmodules
index 7edf0975dc..4451d88384 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,3 +1,3 @@
 [submodule "vendor/llama.cpp"]
 	path = vendor/llama.cpp
-	url = https://github.com/ggerganov/llama.cpp.git
+	url = https://github.com/xaptronic/llama.cpp.git
diff --git a/grammar_test.py b/grammar_test.py
new file mode 100644
--- /dev/null
+++ b/grammar_test.py
@@ -0,0 +1,45 @@
+"""Interactive smoke test for grammar-constrained sampling.
+
+Loads a model with the grammar below attached and opens a Python shell
+with `llm` in scope.
+
+Usage:
+    python grammar_test.py /path/to/model.bin
+
+The model path may also be supplied via the LLAMA_MODEL_PATH environment
+variable.
+"""
+import code
+import os
+import sys
+
+from llama_cpp import Llama
+
+grammar = """root ::= nav eol (commands eol)*
+commands ::= t | info
+nav ::= "nav(\\"admin/" [a-z/]* "\\")"
+info ::= "info(" setting ")"
+t ::= "t(" setting ", " value ")"
+value ::= color | string | number | boolean
+color ::= "#" [0-9a-f][0-9a-f][0-9a-f][0-9a-f][0-9a-f][0-9a-f]
+setting ::= "\\"" [a-z ]+ "\\""
+string ::= "\\"" [ \\t!#-\\[\\]-~]* "\\""
+number ::= [0-9]+
+boolean ::= ("true" | "false")
+eol ::= "\\n"
+"""
+
+
+if __name__ == "__main__":
+    # No machine-specific default: the path must come from the caller.
+    model_path = sys.argv[1] if len(sys.argv) > 1 else os.environ.get("LLAMA_MODEL_PATH")
+    if not model_path:
+        sys.exit("usage: python grammar_test.py MODEL_PATH (or set LLAMA_MODEL_PATH)")
+
+    llm = Llama(
+        model_path=model_path,
+        n_ctx=2048,
+        grammar=grammar,
+    )
+
+    code.interact(local=globals())
diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 4b6ce8c37b..20e6eb0835 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -219,6 +219,7 @@ def __init__(
         last_n_tokens_size: int = 64,
         lora_base: Optional[str] = None,
         lora_path: Optional[str] = None,
+        grammar: Optional[str] = None,
         verbose: bool = True,
     ):
         """Load a llama.cpp model from `model_path`.
@@ -299,6 +300,22 @@ def __init__(
                     f"Failed to apply LoRA from lora path: {self.lora_path} to base path: {self.lora_base}"
                 )
 
+        # Grammar-constrained sampling is opt-in.  Both handles are always
+        # initialised so `_sample` can test `self.grammar` even when no
+        # grammar was supplied (otherwise it raises AttributeError).
+        self.parse_state = None
+        self.grammar = None
+        if grammar:
+            self.parse_state = llama_cpp.llama_grammar_parse(
+                llama_cpp.c_char_p(grammar.encode("utf-8"))
+            )
+            # ctypes maps a NULL `c_void_p` result to None.
+            if not self.parse_state:
+                raise ValueError("Failed to parse grammar")
+            self.grammar = llama_cpp.llama_grammar_from_state(self.parse_state)
+            if not self.grammar:
+                raise RuntimeError("Failed to create grammar from parse state")
+
         if self.verbose:
             print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
 
@@ -496,8 +513,17 @@ def _sample(
         )
         if not penalize_nl:
             candidates.data[self._token_nl].logit = llama_cpp.c_float(nl_logit)
+
+        if self.grammar:
+            # Apply the grammar to the candidates before a token is picked.
+            llama_cpp.llama_sample_grammar(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
+                grammar=self.grammar,
+            )
+
         if temp.value == 0.0:
-            return llama_cpp.llama_sample_token_greedy(
+            sampled_token = llama_cpp.llama_sample_token_greedy(
                 ctx=self.ctx,
                 candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
             )
@@ -509,7 +535,7 @@ def _sample(
                 candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
                 temp=temp,
             )
-            return llama_cpp.llama_sample_token_mirostat(
+            sampled_token = llama_cpp.llama_sample_token_mirostat(
                 ctx=self.ctx,
                 candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
                 tau=mirostat_tau,
@@ -524,7 +550,7 @@ def _sample(
                 candidates=llama_cpp.ctypes.pointer(candidates),
                 temp=temp,
             )
-            return llama_cpp.llama_sample_token_mirostat_v2(
+            sampled_token = llama_cpp.llama_sample_token_mirostat_v2(
                 ctx=self.ctx,
                 candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
                 tau=mirostat_tau,
@@ -561,11 +587,21 @@ def _sample(
                 candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
                 temp=temp,
             )
-            return llama_cpp.llama_sample_token(
+            sampled_token = llama_cpp.llama_sample_token(
                 ctx=self.ctx,
                 candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
             )
 
+        if self.grammar:
+            # Pass the chosen token through the grammar; the call returns a token id.
+            sampled_token = llama_cpp.llama_grammar_accept_token(
+                ctx=self.ctx,
+                grammar=self.grammar,
+                token=sampled_token,
+            )
+
+        return sampled_token
+
     def sample(
         self,
         top_k: int = 40,
diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py
--- a/llama_cpp/llama_cpp.py
+++ b/llama_cpp/llama_cpp.py
@@ -113,6 +113,10 @@ def _load_shared_library(lib_base_name: str):
 llama_token = c_int
 llama_token_p = POINTER(llama_token)
 
+# Opaque handles for the grammar API: parse state and grammar.
+parse_state_p = c_void_p
+llama_grammar_p = c_void_p
+
 
 # typedef struct llama_token_data {
 #     llama_token id; // token id
@@ -793,6 +797,62 @@ def llama_sample_temperature(
 _lib.llama_sample_temperature.restype = None
 
 
+# Parse grammar text (UTF-8 bytes) into an opaque parse-state handle.
+def llama_grammar_parse(grammar: bytes) -> parse_state_p:
+    return _lib.llama_grammar_parse(grammar)
+
+
+_lib.llama_grammar_parse.argtypes = [
+    c_char_p,
+]
+_lib.llama_grammar_parse.restype = parse_state_p
+
+
+# Build a grammar handle from the parse state returned by llama_grammar_parse.
+def llama_grammar_from_state(parse_state: parse_state_p) -> llama_grammar_p:
+    return _lib.llama_grammar_from_state(parse_state)
+
+
+_lib.llama_grammar_from_state.argtypes = [
+    parse_state_p,
+]
+_lib.llama_grammar_from_state.restype = llama_grammar_p
+
+
+# Apply `grammar` to the candidate array; the C function returns void.
+def llama_sample_grammar(
+    ctx: llama_context_p,
+    candidates,  # type: _Pointer[llama_token_data_array]
+    grammar: llama_grammar_p,
+):
+    return _lib.llama_sample_grammar(ctx, candidates, grammar)
+
+
+_lib.llama_sample_grammar.argtypes = [
+    llama_context_p,
+    llama_token_data_array_p,
+    llama_grammar_p,
+]
+_lib.llama_sample_grammar.restype = None
+
+
+# Pass the sampled `token` to `grammar` and return the resulting token id.
+def llama_grammar_accept_token(
+    ctx: llama_context_p,
+    grammar: llama_grammar_p,
+    token: llama_token,
+) -> int:
+    return _lib.llama_grammar_accept_token(ctx, grammar, token)
+
+
+_lib.llama_grammar_accept_token.argtypes = [
+    llama_context_p,
+    llama_grammar_p,
+    llama_token,
+]
+_lib.llama_grammar_accept_token.restype = llama_token
+
+
 # @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
 # @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
 # @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
diff --git a/vendor/llama.cpp b/vendor/llama.cpp
index 4de0334f5c..9d0fcb0c35 160000
--- a/vendor/llama.cpp
+++ b/vendor/llama.cpp
@@ -1 +1 @@
-Subproject commit 4de0334f5cabf4696eced2e5d6e279fdfaa6c0f2
+Subproject commit 9d0fcb0c350305a91ce7460c57228f2d259a804f