diff --git a/.github/workflows/build-and-release.yaml b/.github/workflows/build-and-release.yaml index 6cbac0cb1..df6201ee7 100644 --- a/.github/workflows/build-and-release.yaml +++ b/.github/workflows/build-and-release.yaml @@ -48,10 +48,16 @@ jobs: CIBW_REPAIR_WHEEL_COMMAND: "" # Linux needs auditwheel repair so manylinux and musllinux wheels are # published with distinct platform tags instead of generic linux tags. - CIBW_REPAIR_WHEEL_COMMAND_LINUX: "auditwheel repair -w {dest_dir} {wheel}" + CIBW_REPAIR_WHEEL_COMMAND_LINUX: "LD_LIBRARY_PATH=/project/llama_cpp/lib auditwheel repair -w {dest_dir} {wheel}" + # The release wheel is tagged py3-none, so one build per platform + # covers all supported Python versions and avoids duplicate names. + CIBW_BUILD_LINUX: "cp38-*" + CIBW_BUILD_MACOS: "cp39-*" + CIBW_BUILD_WINDOWS: "cp39-*" # Skip cibuildwheel's default i686 sidecar and keep Linux release # wheels on a portable x86_64 CPU baseline. CIBW_ARCHS_LINUX: "auto64" + CIBW_ARCHS_WINDOWS: "AMD64" CIBW_ENVIRONMENT_LINUX: CMAKE_ARGS="-DGGML_NATIVE=off" # Keep macOS release wheels on a portable CPU baseline instead of # inheriting the hosted runner's native flags. @@ -82,7 +88,9 @@ jobs: # Keep native arm64 builds on a portable CPU baseline instead of # tuning wheels to the hosted runner. CIBW_ENVIRONMENT: CMAKE_ARGS="-DGGML_NATIVE=off" - CIBW_BUILD: "cp38-* cp39-* cp310-* cp311-* cp312-*" + # The release wheel is tagged py3-none, so one build covers all + # supported Python versions and avoids duplicate wheel names. + CIBW_BUILD: "cp38-*" with: output-dir: wheelhouse diff --git a/.github/workflows/build-wheels-cuda.yaml b/.github/workflows/build-wheels-cuda.yaml index 17daaa12a..be55bf483 100644 --- a/.github/workflows/build-wheels-cuda.yaml +++ b/.github/workflows/build-wheels-cuda.yaml @@ -20,10 +20,17 @@ jobs: id: set-matrix run: | $matrix = @{ - 'os' = @('ubuntu-22.04') #, 'windows-2022') - 'pyver' = @("3.9", "3.10", "3.11", "3.12") - 'cuda' = @("12.1.1", "12.2.2", "12.3.2", "12.4.1") #, "12.5.1", "12.6.1") + 'os' = @('ubuntu-22.04', 'windows-2022') + # wheel.py-api = "py3" makes the CUDA wheel interpreter-agnostic, + # so one builder per toolkit version is sufficient. + 'pyver' = @("3.9") + 'cuda' = @("12.1.1", "12.2.2", "12.3.2", "12.4.1", "12.5.1") 'releasetag' = @("basic") + 'exclude' = @( + @{ 'os' = 'windows-2022'; 'cuda' = '12.1.1' }, + @{ 'os' = 'windows-2022'; 'cuda' = '12.2.2' }, + @{ 'os' = 'windows-2022'; 'cuda' = '12.3.2' } + ) } $matrixOut = ConvertTo-Json $matrix -Compress @@ -43,11 +50,11 @@ jobs: AVXVER: ${{ matrix.releasetag }} steps: - - name: Add MSBuild to PATH + - name: Set up MSVC if: runner.os == 'Windows' - uses: microsoft/setup-msbuild@v2 + uses: ilammy/msvc-dev-cmd@v1 with: - vs-version: '[16.11,16.12)' + arch: x64 - uses: actions/checkout@v4 with: @@ -67,32 +74,6 @@ jobs: add-pip-as-python-dependency: true auto-activate-base: false - - name: VS Integration Cache - id: vs-integration-cache - if: runner.os == 'Windows' - uses: actions/cache@v4 - with: - path: ./MSBuildExtensions - key: cuda-${{ matrix.cuda }}-vs-integration - - - name: Get Visual Studio Integration - if: runner.os == 'Windows' && steps.vs-integration-cache.outputs.cache-hit != 'true' - run: | - if ($env:CUDAVER -eq '12.1.1') {$x = '12.1.0'} else {$x = $env:CUDAVER} - $links = (Invoke-RestMethod 'https://raw.githubusercontent.com/Jimver/cuda-toolkit/master/src/links/windows-links.ts').Trim().split().where({$_ -ne ''}) - for ($i=$q=0;$i -lt $links.count -and $q -lt 2;$i++) {if ($links[$i] -eq "'$x',") {$q++}} - Invoke-RestMethod $links[$i].Trim("'") -OutFile 'cudainstaller.zip' - & 'C:\Program Files\7-Zip\7z.exe' e cudainstaller.zip -oMSBuildExtensions -r *\MSBuildExtensions\* > $null - Remove-Item 'cudainstaller.zip' - - - name: Install Visual Studio Integration - if: runner.os == 'Windows' - run: | - $y = (gi '.\MSBuildExtensions').fullname + '\*' - (gi 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\MSBuild\Microsoft\VC\*\BuildCustomizations').fullname.foreach({cp $y $_}) - $cupath = 'CUDA_PATH_V' + $env:CUDAVER.Remove($env:CUDAVER.LastIndexOf('.')).Replace('.','_') - echo "$cupath=$env:CONDA_PREFIX" >> $env:GITHUB_ENV - - name: Install Dependencies env: MAMBA_DOWNLOAD_FAILFAST: "0" @@ -101,24 +82,45 @@ jobs: $cudaVersion = $env:CUDAVER $cudaChannel = "nvidia/label/cuda-$cudaVersion" if ($IsLinux) { - # Keep nvcc, cudart, and headers on the same NVIDIA label so the - # detected toolkit version matches the published wheel tag. - mamba install -y --channel-priority flexible --override-channels -c $cudaChannel "$cudaChannel::cuda-toolkit=$cudaVersion" "$cudaChannel::cuda-nvcc_linux-64=$cudaVersion" "$cudaChannel::cuda-cudart" "$cudaChannel::cuda-cudart-dev" + mamba install -y --channel-priority flexible --override-channels -c $cudaChannel "${cudaChannel}::cuda-toolkit=$cudaVersion" "${cudaChannel}::cuda-nvcc_linux-64" "${cudaChannel}::cuda-cccl" "${cudaChannel}::cuda-cudart" "${cudaChannel}::cuda-cudart-dev" + } elseif ($IsWindows) { + if ($cudaVersion -like '12.5.*') { + # The Windows 12.5 toolkit meta-package pulls compiler activation + # scripts that overflow cmd.exe after MSVC is already initialized. + mamba install -y --channel-priority flexible --override-channels -c $cudaChannel "${cudaChannel}::cuda-nvcc_win-64" "${cudaChannel}::cuda-cccl" "${cudaChannel}::cuda-libraries-dev=$cudaVersion" "${cudaChannel}::cuda-cudart" "${cudaChannel}::cuda-cudart-dev" + } else { + mamba install -y --channel-priority flexible --override-channels -c $cudaChannel "${cudaChannel}::cuda-toolkit=$cudaVersion" "${cudaChannel}::cuda-nvcc_win-64" "${cudaChannel}::cuda-cccl" "${cudaChannel}::cuda-cudart" "${cudaChannel}::cuda-cudart-dev" + } } else { - mamba install -y --channel-priority flexible --override-channels -c $cudaChannel "$cudaChannel::cuda-toolkit=$cudaVersion" + throw 'Unsupported CUDA wheel build platform' } if ($LASTEXITCODE -ne 0) { exit $LASTEXITCODE } - python -m pip install build wheel + if ($IsWindows) { + python -m pip install build wheel ninja + } else { + python -m pip install build wheel + } - name: Build Wheel run: | - $env:CUDA_PATH = $env:CONDA_PREFIX - $env:CUDA_HOME = $env:CONDA_PREFIX - $env:CUDA_TOOLKIT_ROOT_DIR = $env:CONDA_PREFIX + $pathSeparator = if ($IsWindows) { ';' } else { ':' } + if ($IsWindows) { + $cudaRoot = Join-Path $env:CONDA_PREFIX 'Library' + } elseif (Test-Path (Join-Path $env:CONDA_PREFIX 'targets/x86_64-linux/include/cuda_runtime.h')) { + $cudaRoot = Join-Path $env:CONDA_PREFIX 'targets/x86_64-linux' + } else { + $cudaRoot = $env:CONDA_PREFIX + } + + $env:CUDA_PATH = $cudaRoot + $env:CUDA_HOME = $cudaRoot + $env:CUDAToolkit_ROOT = $cudaRoot + $env:CUDA_TOOLKIT_ROOT_DIR = $cudaRoot $cudaHostCompilerArg = '' - $env:CMAKE_ARGS = '' + $cudaRootCmake = $cudaRoot.Replace('\', '/') + $env:CMAKE_ARGS = "-DCUDAToolkit_ROOT=$cudaRootCmake -DCUDA_TOOLKIT_ROOT_DIR=$cudaRootCmake" if ($IsLinux) { if (Test-Path '/usr/bin/g++-12') { $env:CC = '/usr/bin/gcc-12' @@ -126,27 +128,41 @@ jobs: $env:CUDAHOSTCXX = '/usr/bin/g++-12' $cudaHostCompilerArg = " -DCMAKE_CUDA_HOST_COMPILER=$env:CUDAHOSTCXX" } - if (Test-Path (Join-Path $env:CONDA_PREFIX 'include/cuda_runtime.h')) { - $env:CUDAToolkit_ROOT = $env:CONDA_PREFIX - $env:CUDA_TOOLKIT_ROOT_DIR = $env:CONDA_PREFIX - $env:CMAKE_ARGS = "-DCUDAToolkit_ROOT=$env:CONDA_PREFIX -DCUDA_TOOLKIT_ROOT_DIR=$env:CONDA_PREFIX$cudaHostCompilerArg" - $env:CPATH = "$env:CONDA_PREFIX/include:$env:CPATH" - $env:CPLUS_INCLUDE_PATH = "$env:CONDA_PREFIX/include:$env:CPLUS_INCLUDE_PATH" - $env:LIBRARY_PATH = "$env:CONDA_PREFIX/lib:$env:LIBRARY_PATH" - $env:LD_LIBRARY_PATH = "$env:CONDA_PREFIX/lib:$env:LD_LIBRARY_PATH" - } else { - $env:CMAKE_ARGS = $cudaHostCompilerArg.Trim() - } + $env:CMAKE_ARGS = "-DCUDAToolkit_ROOT=$cudaRoot -DCUDA_TOOLKIT_ROOT_DIR=$cudaRoot$cudaHostCompilerArg" + $env:CPATH = "$cudaRoot/include$pathSeparator$env:CPATH" + $env:CPLUS_INCLUDE_PATH = "$cudaRoot/include$pathSeparator$env:CPLUS_INCLUDE_PATH" + $env:LIBRARY_PATH = "$cudaRoot/lib$pathSeparator$env:CONDA_PREFIX/lib$pathSeparator$env:LIBRARY_PATH" + $env:LD_LIBRARY_PATH = "$cudaRoot/lib$pathSeparator$env:CONDA_PREFIX/lib$pathSeparator$env:LD_LIBRARY_PATH" + } elseif ($IsWindows) { + $ninjaPath = ((Get-Command ninja -ErrorAction Stop).Source).Replace('\', '/') + $env:CMAKE_GENERATOR = 'Ninja' + $env:CMAKE_MAKE_PROGRAM = $ninjaPath + $env:PATH = "$(Join-Path $cudaRoot 'bin')$pathSeparator$env:PATH" } - $nvccPath = Join-Path $env:CONDA_PREFIX 'bin/nvcc' - if (-not (Test-Path $nvccPath)) { - $nvccPath = Join-Path $env:CONDA_PREFIX 'targets/x86_64-linux/bin/nvcc' + + if ($IsWindows) { + $nvccCandidates = @( + (Join-Path $cudaRoot 'bin\nvcc.exe'), + (Join-Path $env:CONDA_PREFIX 'Library\bin\nvcc.exe'), + (Join-Path $env:CONDA_PREFIX 'bin\nvcc.exe') + ) + } else { + $nvccCandidates = @( + (Join-Path $env:CONDA_PREFIX 'bin/nvcc'), + (Join-Path $env:CONDA_PREFIX 'targets/x86_64-linux/bin/nvcc') + ) } - if (-not (Test-Path $nvccPath)) { + $nvccPath = $nvccCandidates | Where-Object { Test-Path $_ } | Select-Object -First 1 + if (-not $nvccPath) { throw 'Failed to find nvcc in the conda environment' } $env:CUDACXX = $nvccPath - $env:PATH = "$(Split-Path $nvccPath):$env:PATH" + $env:PATH = "$(Split-Path $nvccPath)$pathSeparator$env:PATH" + if ($IsWindows) { + $nvccPathCmake = $nvccPath.Replace('\', '/') + $env:CUDACXX = $nvccPathCmake + $env:CMAKE_ARGS = "-DCMAKE_CUDA_COMPILER=$nvccPathCmake -DCMAKE_CUDA_COMPILER_ARG1=-allow-unsupported-compiler -DCMAKE_MAKE_PROGRAM=$env:CMAKE_MAKE_PROGRAM $env:CMAKE_ARGS" + } $nvccVersion = ((& $nvccPath --version) | Select-String 'release ([0-9]+\.[0-9]+)').Matches[0].Groups[1].Value if (-not $nvccVersion) { throw 'Failed to detect the installed CUDA toolkit version' @@ -156,16 +172,8 @@ jobs: # Build real cubins for the supported GPUs, including sm_70, and keep # one forward-compatible PTX target instead of embedding PTX for every # SM. This keeps the wheel under GitHub's 2 GiB release-asset limit. - $env:CMAKE_ARGS = "-DGGML_CUDA_FORCE_MMQ=ON -DGGML_CUDA=on -DCMAKE_CUDA_ARCHITECTURES=70-real;75-real;80-real;86-real;89-real;90-real;90-virtual -DCMAKE_CUDA_FLAGS=--allow-unsupported-compiler $env:CMAKE_ARGS" - # if ($env:AVXVER -eq 'AVX') { + $env:CMAKE_ARGS = "-DGGML_CUDA_FORCE_MMQ=ON -DGGML_CUDA=on -DCMAKE_CUDA_ARCHITECTURES=70-real;75-real;80-real;86-real;89-real;90-real;90-virtual -DCMAKE_CUDA_FLAGS=-allow-unsupported-compiler -DCMAKE_CUDA_FLAGS_INIT=-allow-unsupported-compiler $env:CMAKE_ARGS" $env:CMAKE_ARGS = $env:CMAKE_ARGS + ' -DGGML_AVX2=off -DGGML_FMA=off -DGGML_F16C=off' - # } - # if ($env:AVXVER -eq 'AVX512') { - # $env:CMAKE_ARGS = $env:CMAKE_ARGS + ' -DGGML_AVX512=on' - # } - # if ($env:AVXVER -eq 'basic') { - # $env:CMAKE_ARGS = $env:CMAKE_ARGS + ' -DGGML_AVX=off -DGGML_AVX2=off -DGGML_FMA=off -DGGML_F16C=off' - # } python -m build --wheel # Publish tags that reflect the actual installed toolkit version. Write-Output "CUDA_VERSION=$cudaTagVersion" >> $env:GITHUB_ENV diff --git a/.gitmodules b/.gitmodules index 7edf0975d..f56cca32d 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,3 @@ [submodule "vendor/llama.cpp"] path = vendor/llama.cpp - url = https://github.com/ggerganov/llama.cpp.git + url = https://github.com/ggml-org/llama.cpp.git diff --git a/CHANGELOG.md b/CHANGELOG.md index fbe5b6b6f..18c6af161 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,26 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] -- feat: Update llama.cpp to ggerganov/llama.cpp@3bd9aa1f9 and sync Python bindings +- feat: Update llama.cpp to ggml-org/llama.cpp@b9a2170fc and sync Python bindings +- chore: Migrate llama.cpp submodule URL to ggml-org/llama.cpp by @shalinib-ibm in #2034 +- fix: Enable unified KV cache for embedding contexts to preserve full per-sequence context in batch embedding calls by @SanjanaB123 in #2217 + +## [0.3.23] + +- feat: Update llama.cpp to ggerganov/llama.cpp@7d442abf +- fix: Correct batched embedding outputs for multi-sequence `embed()` calls by @Anai-Guo in #2205 +- fix: Configure embedding contexts with enough sequence slots for batched `embed()` calls +- fix: Mark all embedding input tokens as outputs to avoid llama.cpp override warnings by @Anai-Guo in #2212 + +## [0.3.22] + +- feat: Update llama.cpp to ggerganov/llama.cpp@63d93d173 +- feat(ci): Re-enable Windows CUDA wheels and add CUDA 12.5.1 wheel builds + +## [0.3.21] + +- feat: Update llama.cpp to ggerganov/llama.cpp@f53577432 and sync Python bindings +- fix(ci): Build one arm64 release wheel for `py3-none` wheel publishing ## [0.3.20] diff --git a/llama_cpp/__init__.py b/llama_cpp/__init__.py index 83177c065..eb37da209 100644 --- a/llama_cpp/__init__.py +++ b/llama_cpp/__init__.py @@ -1,4 +1,4 @@ from .llama_cpp import * from .llama import * -__version__ = "0.3.20" +__version__ = "0.3.23" diff --git a/llama_cpp/_internals.py b/llama_cpp/_internals.py index cde52c8c8..24f6fddc7 100644 --- a/llama_cpp/_internals.py +++ b/llama_cpp/_internals.py @@ -522,7 +522,7 @@ def add_sequence(self, batch: Sequence[int], seq_id: int, logits_all: bool): self.batch.seq_id[j][0] = seq_id self.batch.n_seq_id[j] = 1 self.batch.logits[j] = logits_all - self.batch.logits[n_tokens - 1] = True + self.batch.logits[n_tokens0 + n_tokens - 1] = True class LlamaTokenDataArray: diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 11fe169cf..75c74b41f 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -397,6 +397,12 @@ def __init__( self.context_params.n_batch = self.n_batch self.context_params.n_ubatch = min(self.n_batch, n_ubatch) + if embedding: + self.context_params.n_seq_max = min( + self.n_batch, + llama_cpp.llama_max_parallel_sequences(), + ) + self.context_params.kv_unified = True self._ctx = self._stack.enter_context( contextlib.closing( internals.LlamaContext( @@ -1030,10 +1036,17 @@ def embed( """ n_embd = self.n_embd() n_batch = self.n_batch + n_seq_max = self.context_params.n_seq_max # get pooling information pooling_type = self.pooling_type() - logits_all = pooling_type == llama_cpp.LLAMA_POOLING_TYPE_NONE + # In embedding mode every input token must be marked as an output, regardless of + # pooling type. llama.cpp would otherwise override per-token `logits[i]` and emit + # "embeddings required but some input tokens were not marked as outputs -> + # overriding" once per input. Pooling NONE vs MEAN/CLS only changes how the + # per-token outputs are read back (see decode_batch below), not whether they are + # produced. See abetlen/llama-cpp-python#2208. + logits_all = True if self.context_params.embeddings is False: raise RuntimeError( @@ -1104,7 +1117,7 @@ def decode_batch(seq_sizes: List[int]): ) # time to eval batch - if t_batch + n_tokens > n_batch: + if t_batch + n_tokens > n_batch or p_batch >= n_seq_max: decode_batch(s_batch) s_batch = [] t_batch = 0 diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 0a66a5d85..6560b5178 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -199,6 +199,8 @@ def _warn_deprecated(symbol: str, hint: str) -> None: llama_token_p = ctypes.POINTER(llama_token) # typedef int32_t llama_seq_id; llama_seq_id = ctypes.c_int32 +# typedef uint32_t llama_state_seq_flags; +llama_state_seq_flags = ctypes.c_uint32 # enum llama_vocab_type { @@ -503,13 +505,23 @@ def _warn_deprecated(symbol: str, hint: str) -> None: # enum llama_split_mode { -# LLAMA_SPLIT_MODE_NONE = 0, // single GPU -# LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs -# LLAMA_SPLIT_MODE_ROW = 2, // split layers and KV across GPUs, use tensor parallelism if supported +# LLAMA_SPLIT_MODE_NONE = 0, // single GPU +# LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs +# LLAMA_SPLIT_MODE_ROW = 2, // split layers and KV across GPUs, use tensor parallelism if supported +# LLAMA_SPLIT_MODE_TENSOR = 3, # }; LLAMA_SPLIT_MODE_NONE = 0 LLAMA_SPLIT_MODE_LAYER = 1 LLAMA_SPLIT_MODE_ROW = 2 +LLAMA_SPLIT_MODE_TENSOR = 3 + + +# enum llama_context_type { +# LLAMA_CONTEXT_TYPE_DEFAULT = 0, +# LLAMA_CONTEXT_TYPE_MTP = 1, +# }; +LLAMA_CONTEXT_TYPE_DEFAULT = 0 +LLAMA_CONTEXT_TYPE_MTP = 1 # typedef struct llama_token_data { @@ -890,9 +902,11 @@ class llama_sampler_seq_config(ctypes.Structure): # uint32_t n_batch; // logical maximum batch size that can be submitted to llama_decode # uint32_t n_ubatch; // physical maximum batch size # uint32_t n_seq_max; // max number of sequences (i.e. distinct states for recurrent models) +# uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback (0 = no rollback) [EXPERIMENTAL] # int32_t n_threads; // number of threads to use for generation # int32_t n_threads_batch; // number of threads to use for batch processing +# enum llama_context_type ctx_type; // set the context type (e.g. MTP) # enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type` # enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id # enum llama_attention_type attention_type; // attention type to use for embeddings @@ -943,8 +957,10 @@ class llama_context_params(ctypes.Structure): n_batch (int): logical maximum batch size that can be submitted to llama_decode n_ubatch (int): physical maximum batch size n_seq_max (int): max number of sequences (i.e. distinct states for recurrent models) + n_rs_seq (int): number of recurrent-state snapshots per sequence for rollback n_threads (int): number of threads to use for generation n_threads_batch (int): number of threads to use for batch processing + ctx_type (int): context type, from `enum llama_context_type` rope_scaling_type (int): RoPE scaling type, from `enum llama_rope_scaling_type` pooling_type (int): whether to pool (sum) embedding results by sequence id (ignored if no pooling layer) attention_type (int): attention type to use for embeddings @@ -978,8 +994,10 @@ class llama_context_params(ctypes.Structure): n_batch: int n_ubatch: int n_seq_max: int + n_rs_seq: int n_threads: int n_threads_batch: int + ctx_type: int rope_scaling_type: int pooling_type: int attention_type: int @@ -1012,8 +1030,10 @@ class llama_context_params(ctypes.Structure): ("n_batch", ctypes.c_uint32), ("n_ubatch", ctypes.c_uint32), ("n_seq_max", ctypes.c_uint32), + ("n_rs_seq", ctypes.c_uint32), ("n_threads", ctypes.c_int32), ("n_threads_batch", ctypes.c_int32), + ("ctx_type", ctypes.c_int), ("rope_scaling_type", ctypes.c_int), ("pooling_type", ctypes.c_int), ("attention_type", ctypes.c_int), @@ -1514,54 +1534,6 @@ def llama_free(ctx: llama_context_p, /): ... -# enum llama_params_fit_status { -# LLAMA_PARAMS_FIT_STATUS_SUCCESS = 0, -# LLAMA_PARAMS_FIT_STATUS_FAILURE = 1, -# LLAMA_PARAMS_FIT_STATUS_ERROR = 2, -# }; -LLAMA_PARAMS_FIT_STATUS_SUCCESS = 0 -LLAMA_PARAMS_FIT_STATUS_FAILURE = 1 -LLAMA_PARAMS_FIT_STATUS_ERROR = 2 - - -# LLAMA_API enum llama_params_fit_status llama_params_fit( -# const char * path_model, -# struct llama_model_params * mparams, -# struct llama_context_params * cparams, -# float * tensor_split, -# struct llama_model_tensor_buft_override * tensor_buft_overrides, -# size_t * margins, -# uint32_t n_ctx_min, -# enum ggml_log_level log_level); -@ctypes_function( - "llama_params_fit", - [ - ctypes.c_char_p, - ctypes.POINTER(llama_model_params), - ctypes.POINTER(llama_context_params), - ctypes.POINTER(ctypes.c_float), - ctypes.c_void_p, - ctypes.POINTER(ctypes.c_size_t), - ctypes.c_uint32, - ctypes.c_int, - ], - ctypes.c_int, -) -def llama_params_fit( - path_model: bytes, - mparams: CtypesPointerOrRef[llama_model_params], - cparams: CtypesPointerOrRef[llama_context_params], - tensor_split: Optional[CtypesPointer[ctypes.c_float]], - tensor_buft_overrides: ctypes.c_void_p, - margins: Optional[CtypesPointer[ctypes.c_size_t]], - n_ctx_min: int, - log_level: int, - /, -) -> int: - """Fit model and context parameters for a model path.""" - ... - - # LLAMA_API int64_t llama_time_us(void); @ctypes_function( "llama_time_us", @@ -1635,6 +1607,11 @@ def llama_n_ubatch(ctx: llama_context_p, /) -> int: ... def llama_n_seq_max(ctx: llama_context_p, /) -> int: ... +# LLAMA_API uint32_t llama_n_rs_seq (const struct llama_context * ctx); +@ctypes_function("llama_n_rs_seq", [llama_context_p_ctypes], ctypes.c_uint32) +def llama_n_rs_seq(ctx: llama_context_p, /) -> int: ... + + # DEPRECATED(LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model), "use llama_model_n_ctx_train instead"); @ctypes_function("llama_n_ctx_train", [llama_model_p_ctypes], ctypes.c_int32) def llama_n_ctx_train(model: llama_model_p, /) -> int: ... @@ -2881,6 +2858,95 @@ def llama_state_seq_load_file( ) -> int: ... +# define LLAMA_STATE_SEQ_FLAGS_NONE 0 +LLAMA_STATE_SEQ_FLAGS_NONE = 0 + +# for backwards-compat +# define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY 1 +LLAMA_STATE_SEQ_FLAGS_SWA_ONLY = 1 + +# work only with partial states, such as SWA KV cache or recurrent cache +# (e.g. Mamba) +# define LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY 1 +LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY = 1 + +# keeps the tensor data on device buffers +# (i.e. not accessible in host memory, but faster save/load) +# define LLAMA_STATE_SEQ_FLAGS_ON_DEVICE 2 +LLAMA_STATE_SEQ_FLAGS_ON_DEVICE = 2 + + +# LLAMA_API size_t llama_state_seq_get_size_ext( +# struct llama_context * ctx, +# llama_seq_id seq_id, +# llama_state_seq_flags flags); +@ctypes_function( + "llama_state_seq_get_size_ext", + [llama_context_p_ctypes, llama_seq_id, llama_state_seq_flags], + ctypes.c_size_t, +) +def llama_state_seq_get_size_ext( + ctx: llama_context_p, + seq_id: llama_seq_id, + flags: llama_state_seq_flags, + /, +) -> int: ... + + +# LLAMA_API size_t llama_state_seq_get_data_ext( +# struct llama_context * ctx, +# uint8_t * dst, +# size_t size, +# llama_seq_id seq_id, +# llama_state_seq_flags flags); +@ctypes_function( + "llama_state_seq_get_data_ext", + [ + llama_context_p_ctypes, + ctypes.POINTER(ctypes.c_uint8), + ctypes.c_size_t, + llama_seq_id, + llama_state_seq_flags, + ], + ctypes.c_size_t, +) +def llama_state_seq_get_data_ext( + ctx: llama_context_p, + dst: CtypesArray[ctypes.c_uint8], + size: Union[ctypes.c_size_t, int], + seq_id: llama_seq_id, + flags: llama_state_seq_flags, + /, +) -> int: ... + + +# LLAMA_API size_t llama_state_seq_set_data_ext( +# struct llama_context * ctx, +# const uint8_t * src, +# size_t size, +# llama_seq_id dest_seq_id, +# llama_state_seq_flags flags); +@ctypes_function( + "llama_state_seq_set_data_ext", + [ + llama_context_p_ctypes, + ctypes.POINTER(ctypes.c_uint8), + ctypes.c_size_t, + llama_seq_id, + llama_state_seq_flags, + ], + ctypes.c_size_t, +) +def llama_state_seq_set_data_ext( + ctx: llama_context_p, + src: CtypesArray[ctypes.c_uint8], + size: Union[ctypes.c_size_t, int], + dest_seq_id: llama_seq_id, + flags: llama_state_seq_flags, + /, +) -> int: ... + + # // # // Decoding # // @@ -4867,13 +4933,6 @@ def llama_perf_sampler_print(chain: llama_sampler_p, /): ... def llama_perf_sampler_reset(chain: llama_sampler_p, /): ... -# // print a breakdown of per-device memory use via LLAMA_LOG: -@ctypes_function("llama_memory_breakdown_print", [llama_context_p_ctypes], None) -def llama_memory_breakdown_print(ctx: llama_context_p, /): - """Print a breakdown of per-device memory use.""" - ... - - # // # // training # // diff --git a/llama_cpp/mtmd_cpp.py b/llama_cpp/mtmd_cpp.py index f28402775..f2b0ed2de 100644 --- a/llama_cpp/mtmd_cpp.py +++ b/llama_cpp/mtmd_cpp.py @@ -8,9 +8,9 @@ c_int, c_uint8, c_uint32, + c_size_t, c_float, c_void_p, - c_size_t, POINTER, _Pointer, # type: ignore Structure, @@ -123,6 +123,34 @@ class mtmd_input_text(Structure): ] +class mtmd_decoder_pos(Structure): + """Decoder attention position for M-RoPE models.""" + + _fields_ = [ + ("t", c_uint32), + ("x", c_uint32), + ("y", c_uint32), + ("z", c_uint32), + ] + + +# struct mtmd_caps { +# bool inp_vision; +# bool inp_audio; +# }; +class mtmd_caps(Structure): + """Capabilities exposed by an mmproj file.""" + + if TYPE_CHECKING: + inp_vision: bool + inp_audio: bool + + _fields_ = [ + ("inp_vision", c_bool), + ("inp_audio", c_bool), + ] + + ################################################ # mtmd.h functions ################################################ @@ -165,35 +193,41 @@ def mtmd_init_from_file( def mtmd_free(ctx: mtmd_context_p, /): ... -# MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx); -@ctypes_function("mtmd_decode_use_non_causal", [mtmd_context_p_ctypes], c_bool) -def mtmd_decode_use_non_causal(ctx: mtmd_context_p, /) -> bool: +# MTMD_API bool mtmd_decode_use_non_causal(const mtmd_context * ctx, const mtmd_input_chunk * chunk); +@ctypes_function( + "mtmd_decode_use_non_causal", + [mtmd_context_p_ctypes, mtmd_input_chunk_p_ctypes], + c_bool, +) +def mtmd_decode_use_non_causal( + ctx: mtmd_context_p, chunk: Optional[mtmd_input_chunk_p], / +) -> bool: """Check whether MTMD decoding uses non-causal attention.""" ... -# MTMD_API bool mtmd_decode_use_mrope(mtmd_context * ctx); +# MTMD_API bool mtmd_decode_use_mrope(const mtmd_context * ctx); @ctypes_function("mtmd_decode_use_mrope", [mtmd_context_p_ctypes], c_bool) def mtmd_decode_use_mrope(ctx: mtmd_context_p, /) -> bool: """Check whether MTMD decoding uses mRoPE.""" ... -# MTMD_API bool mtmd_support_vision(mtmd_context * ctx); +# MTMD_API bool mtmd_support_vision(const mtmd_context * ctx); @ctypes_function("mtmd_support_vision", [mtmd_context_p_ctypes], c_bool) def mtmd_support_vision(ctx: mtmd_context_p, /) -> bool: """Check whether the current model supports vision input.""" ... -# MTMD_API bool mtmd_support_audio(mtmd_context * ctx); +# MTMD_API bool mtmd_support_audio(const mtmd_context * ctx); @ctypes_function("mtmd_support_audio", [mtmd_context_p_ctypes], c_bool) def mtmd_support_audio(ctx: mtmd_context_p, /) -> bool: """Check whether MTMD supports audio.""" ... -# MTMD_API int mtmd_get_audio_sample_rate(mtmd_context * ctx); +# MTMD_API int mtmd_get_audio_sample_rate(const mtmd_context * ctx); @ctypes_function("mtmd_get_audio_sample_rate", [mtmd_context_p_ctypes], c_int) def mtmd_get_audio_sample_rate(ctx: mtmd_context_p, /) -> int: """Get the audio sample rate in Hz. Returns -1 if audio is not supported.""" @@ -242,6 +276,55 @@ def mtmd_bitmap_init_from_audio( def mtmd_bitmap_free(bitmap: mtmd_bitmap_p, /): ... +# MTMD_API uint32_t mtmd_bitmap_get_nx(const mtmd_bitmap * bitmap); +@ctypes_function("mtmd_bitmap_get_nx", [mtmd_bitmap_p_ctypes], c_uint32) +def mtmd_bitmap_get_nx(bitmap: mtmd_bitmap_p, /) -> int: + """Get the bitmap width in pixels.""" + ... + + +# MTMD_API uint32_t mtmd_bitmap_get_ny(const mtmd_bitmap * bitmap); +@ctypes_function("mtmd_bitmap_get_ny", [mtmd_bitmap_p_ctypes], c_uint32) +def mtmd_bitmap_get_ny(bitmap: mtmd_bitmap_p, /) -> int: + """Get the bitmap height in pixels.""" + ... + + +# MTMD_API const unsigned char * mtmd_bitmap_get_data(const mtmd_bitmap * bitmap); +@ctypes_function("mtmd_bitmap_get_data", [mtmd_bitmap_p_ctypes], POINTER(c_uint8)) +def mtmd_bitmap_get_data(bitmap: mtmd_bitmap_p, /) -> Optional[CtypesArray[c_uint8]]: + """Get the raw bitmap data buffer.""" + ... + + +# MTMD_API size_t mtmd_bitmap_get_n_bytes(const mtmd_bitmap * bitmap); +@ctypes_function("mtmd_bitmap_get_n_bytes", [mtmd_bitmap_p_ctypes], c_size_t) +def mtmd_bitmap_get_n_bytes(bitmap: mtmd_bitmap_p, /) -> int: + """Get the bitmap data size in bytes.""" + ... + + +# MTMD_API bool mtmd_bitmap_is_audio(const mtmd_bitmap * bitmap); +@ctypes_function("mtmd_bitmap_is_audio", [mtmd_bitmap_p_ctypes], c_bool) +def mtmd_bitmap_is_audio(bitmap: mtmd_bitmap_p, /) -> bool: + """Check whether the bitmap contains audio data.""" + ... + + +# MTMD_API const char * mtmd_bitmap_get_id(const mtmd_bitmap * bitmap); +@ctypes_function("mtmd_bitmap_get_id", [mtmd_bitmap_p_ctypes], c_char_p) +def mtmd_bitmap_get_id(bitmap: mtmd_bitmap_p, /) -> Optional[bytes]: + """Get the optional bitmap identifier.""" + ... + + +# MTMD_API void mtmd_bitmap_set_id(mtmd_bitmap * bitmap, const char * id); +@ctypes_function("mtmd_bitmap_set_id", [mtmd_bitmap_p_ctypes, c_char_p], None) +def mtmd_bitmap_set_id(bitmap: mtmd_bitmap_p, id: Optional[bytes], /): + """Set the optional bitmap identifier.""" + ... + + # MTMD_API mtmd_input_chunks * mtmd_input_chunks_init(void); @ctypes_function("mtmd_input_chunks_init", [], mtmd_input_chunks_p_ctypes) def mtmd_input_chunks_init() -> Optional[mtmd_input_chunks_p]: ... @@ -315,11 +398,172 @@ def mtmd_input_chunk_get_tokens_text( ) -> Optional["_Pointer[llama_cpp.llama_token]"]: ... +# MTMD_API const mtmd_image_tokens * mtmd_input_chunk_get_tokens_image(const mtmd_input_chunk * chunk); +@ctypes_function( + "mtmd_input_chunk_get_tokens_image", + [mtmd_input_chunk_p_ctypes], + mtmd_image_tokens_p_ctypes, +) +def mtmd_input_chunk_get_tokens_image( + chunk: mtmd_input_chunk_p, / +) -> Optional[mtmd_image_tokens_p]: ... + + +# MTMD_API const char * mtmd_input_chunk_get_id(const mtmd_input_chunk * chunk); +@ctypes_function("mtmd_input_chunk_get_id", [mtmd_input_chunk_p_ctypes], c_char_p) +def mtmd_input_chunk_get_id(chunk: mtmd_input_chunk_p, /) -> Optional[bytes]: + """Get the optional chunk identifier.""" + ... + + +# MTMD_API llama_pos mtmd_input_chunk_get_n_pos(const mtmd_input_chunk * chunk); +@ctypes_function( + "mtmd_input_chunk_get_n_pos", + [mtmd_input_chunk_p_ctypes], + llama_cpp.llama_pos, +) +def mtmd_input_chunk_get_n_pos(chunk: mtmd_input_chunk_p, /) -> int: + """Get the number of positions consumed by the chunk.""" + ... + + +# MTMD_API mtmd_input_chunk * mtmd_input_chunk_copy(const mtmd_input_chunk * chunk); +@ctypes_function( + "mtmd_input_chunk_copy", [mtmd_input_chunk_p_ctypes], mtmd_input_chunk_p_ctypes +) +def mtmd_input_chunk_copy(chunk: mtmd_input_chunk_p, /) -> Optional[mtmd_input_chunk_p]: + """Copy an input chunk and transfer ownership to the caller.""" + ... + + +# MTMD_API void mtmd_input_chunk_free(mtmd_input_chunk * chunk); +@ctypes_function("mtmd_input_chunk_free", [mtmd_input_chunk_p_ctypes], None) +def mtmd_input_chunk_free(chunk: mtmd_input_chunk_p, /): + """Free an owned input chunk.""" + ... + + +# MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens); +@ctypes_function( + "mtmd_image_tokens_get_n_tokens", [mtmd_image_tokens_p_ctypes], c_size_t +) +def mtmd_image_tokens_get_n_tokens(image_tokens: mtmd_image_tokens_p, /) -> int: + """Get the number of image tokens.""" + ... + + +# DEPRECATED(MTMD_API size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens), +# "use mtmd_image_tokens_get_decoder_pos() instead"); +@ctypes_function("mtmd_image_tokens_get_nx", [mtmd_image_tokens_p_ctypes], c_size_t) +def mtmd_image_tokens_get_nx(image_tokens: mtmd_image_tokens_p, /) -> int: + """Get the image token grid width.""" + ... + + +# DEPRECATED(MTMD_API size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens), +# "use mtmd_image_tokens_get_decoder_pos() instead"); +@ctypes_function("mtmd_image_tokens_get_ny", [mtmd_image_tokens_p_ctypes], c_size_t) +def mtmd_image_tokens_get_ny(image_tokens: mtmd_image_tokens_p, /) -> int: + """Get the image token grid height.""" + ... + + +# MTMD_API const char * mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens); +@ctypes_function("mtmd_image_tokens_get_id", [mtmd_image_tokens_p_ctypes], c_char_p) +def mtmd_image_tokens_get_id(image_tokens: mtmd_image_tokens_p, /) -> Optional[bytes]: + """Get the optional image token identifier.""" + ... + + +# MTMD_API llama_pos mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens); +@ctypes_function( + "mtmd_image_tokens_get_n_pos", + [mtmd_image_tokens_p_ctypes], + llama_cpp.llama_pos, +) +def mtmd_image_tokens_get_n_pos(image_tokens: mtmd_image_tokens_p, /) -> int: + """Get the number of positions consumed by the image tokens.""" + ... + + +# MTMD_API struct mtmd_decoder_pos mtmd_image_tokens_get_decoder_pos( +# const mtmd_image_tokens * image_tokens, llama_pos pos_0, size_t i); +@ctypes_function( + "mtmd_image_tokens_get_decoder_pos", + [mtmd_image_tokens_p_ctypes, llama_cpp.llama_pos, c_size_t], + mtmd_decoder_pos, +) +def mtmd_image_tokens_get_decoder_pos( + image_tokens: mtmd_image_tokens_p, + pos_0: llama_cpp.llama_pos, + i: Union[c_size_t, int], + /, +) -> mtmd_decoder_pos: + """Get decoder attention position for an image embedding token.""" + ... + + +# MTMD_API int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens); +@ctypes_function( + "mtmd_encode", + [mtmd_context_p_ctypes, mtmd_image_tokens_p_ctypes], + c_int, +) +def mtmd_encode(ctx: mtmd_context_p, image_tokens: mtmd_image_tokens_p, /) -> int: + """Run an MTMD encode pass for image tokens.""" + ... + + +# MTMD_API int32_t mtmd_encode_chunk(mtmd_context * ctx, const mtmd_input_chunk * chunk); +@ctypes_function( + "mtmd_encode_chunk", + [mtmd_context_p_ctypes, mtmd_input_chunk_p_ctypes], + c_int, +) +def mtmd_encode_chunk(ctx: mtmd_context_p, chunk: mtmd_input_chunk_p, /) -> int: + """Run an MTMD encode pass for a single chunk.""" + ... + + +# MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx); +@ctypes_function("mtmd_get_output_embd", [mtmd_context_p_ctypes], POINTER(c_float)) +def mtmd_get_output_embd(ctx: mtmd_context_p, /) -> Optional[CtypesArray[c_float]]: + """Get output embeddings from the last encode pass.""" + ... + + +# MTMD_API struct mtmd_caps mtmd_get_cap_from_file(const char * mmproj_fname); +@ctypes_function("mtmd_get_cap_from_file", [c_char_p], mtmd_caps) +def mtmd_get_cap_from_file(mmproj_fname: bytes, /) -> mtmd_caps: + """Get mmproj capabilities without initializing a full MTMD context.""" + ... + + +# MTMD_API mtmd_input_chunks * mtmd_test_create_input_chunks(void); +@ctypes_function("mtmd_test_create_input_chunks", [], mtmd_input_chunks_p_ctypes) +def mtmd_test_create_input_chunks() -> Optional[mtmd_input_chunks_p]: + """Create MTMD test chunks for the C API tests.""" + ... + + ################################################ # mtmd-helper.h functions ################################################ +# MTMD_API mtmd_bitmap * mtmd_helper_bitmap_init_from_file(mtmd_context * ctx, const char * fname); +@ctypes_function( + "mtmd_helper_bitmap_init_from_file", + [mtmd_context_p_ctypes, c_char_p], + mtmd_bitmap_p_ctypes, +) +def mtmd_helper_bitmap_init_from_file( + ctx: mtmd_context_p, fname: bytes, / +) -> Optional[mtmd_bitmap_p]: + """Initialize an MTMD bitmap from a file.""" + ... + + # MTMD_API mtmd_bitmap * mtmd_helper_bitmap_init_from_buf(mtmd_context * ctx, const unsigned char * buf, size_t len); @ctypes_function( "mtmd_helper_bitmap_init_from_buf", @@ -339,6 +583,69 @@ def mtmd_helper_bitmap_init_from_buf( def mtmd_helper_get_n_tokens(chunks: mtmd_input_chunks_p, /) -> int: ... +# MTMD_API llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks); +@ctypes_function( + "mtmd_helper_get_n_pos", + [mtmd_input_chunks_p_ctypes], + llama_cpp.llama_pos, +) +def mtmd_helper_get_n_pos(chunks: mtmd_input_chunks_p, /) -> int: + """Count the total positions consumed by the chunks.""" + ... + + +# MTMD_API void mtmd_helper_image_get_decoder_pos( +# const mtmd_image_tokens * image, llama_pos pos_0, struct mtmd_decoder_pos * out_pos); +@ctypes_function( + "mtmd_helper_image_get_decoder_pos", + [mtmd_image_tokens_p_ctypes, llama_cpp.llama_pos, POINTER(mtmd_decoder_pos)], + None, +) +def mtmd_helper_image_get_decoder_pos( + image: mtmd_image_tokens_p, + pos_0: llama_cpp.llama_pos, + out_pos: "_Pointer[mtmd_decoder_pos]", + /, +): + """Fill decoder attention positions for all image embedding tokens.""" + ... + + +# MTMD_API int32_t mtmd_helper_eval_chunks(mtmd_context * ctx, +# struct llama_context * lctx, +# const mtmd_input_chunks * chunks, +# llama_pos n_past, +# llama_seq_id seq_id, +# int32_t n_batch, +# bool logits_last, +# llama_pos * new_n_past); +@ctypes_function( + "mtmd_helper_eval_chunks", + [ + mtmd_context_p_ctypes, + llama_cpp.llama_context_p_ctypes, + mtmd_input_chunks_p_ctypes, + llama_cpp.llama_pos, + llama_cpp.llama_seq_id, + c_int, + c_bool, + POINTER(llama_cpp.llama_pos), + ], + c_int, +) +def mtmd_helper_eval_chunks( + ctx: mtmd_context_p, + lctx: llama_cpp.llama_context_p, + chunks: mtmd_input_chunks_p, + n_past: llama_cpp.llama_pos, + seq_id: llama_cpp.llama_seq_id, + n_batch: Union[c_int, int], + logits_last: Union[c_bool, bool], + new_n_past: "_Pointer[llama_cpp.llama_pos]", + /, +) -> int: ... + + # MTMD_API int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx, # struct llama_context * lctx, # const mtmd_input_chunk * chunk, @@ -374,6 +681,43 @@ def mtmd_helper_eval_chunk_single( ) -> int: ... +# MTMD_API int32_t mtmd_helper_decode_image_chunk(mtmd_context * ctx, +# struct llama_context * lctx, +# const mtmd_input_chunk * chunk, +# float * encoded_embd, +# llama_pos n_past, +# llama_seq_id seq_id, +# int32_t n_batch, +# llama_pos * new_n_past); +@ctypes_function( + "mtmd_helper_decode_image_chunk", + [ + mtmd_context_p_ctypes, + llama_cpp.llama_context_p_ctypes, + mtmd_input_chunk_p_ctypes, + POINTER(c_float), + llama_cpp.llama_pos, + llama_cpp.llama_seq_id, + c_int, + POINTER(llama_cpp.llama_pos), + ], + c_int, +) +def mtmd_helper_decode_image_chunk( + ctx: mtmd_context_p, + lctx: llama_cpp.llama_context_p, + chunk: mtmd_input_chunk_p, + encoded_embd: CtypesArray[c_float], + n_past: llama_cpp.llama_pos, + seq_id: llama_cpp.llama_seq_id, + n_batch: Union[c_int, int], + new_n_past: "_Pointer[llama_cpp.llama_pos]", + /, +) -> int: + """Decode a pre-encoded image chunk.""" + ... + + # MTMD_API void mtmd_log_set(ggml_log_callback log_callback, void * user_data); @ctypes_function( "mtmd_log_set", diff --git a/mkdocs.yml b/mkdocs.yml index 79a9e67a1..37e1002e8 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -36,7 +36,7 @@ plugins: - typing - typing_extensions - ctypes - import: + inventories: - https://docs.python.org/3/objects.inv - https://numpy.org/doc/stable/objects.inv diff --git a/tests/test_llama.py b/tests/test_llama.py index 23928fff6..d4e6031c7 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -247,3 +247,15 @@ def test_real_llama_embeddings(llama_cpp_embedding_model_path): ) embedding = model.embed("Hello World") assert len(embedding) > 0 + + prompts = ["Hello World", "A different prompt"] + individual_embeddings = [model.embed(prompt) for prompt in prompts] + batched_embeddings = model.embed(prompts) + + assert len(batched_embeddings) == len(prompts) + for individual, batched in zip(individual_embeddings, batched_embeddings): + np.testing.assert_allclose(batched, individual, rtol=1e-4, atol=1e-4) + + repeated_embeddings = model.embed(list(reversed(prompts))) + assert len(repeated_embeddings) == len(prompts) + assert all(len(repeated) == len(embedding) for repeated in repeated_embeddings) diff --git a/vendor/llama.cpp b/vendor/llama.cpp index 3bd9aa1f9..b9a2170fc 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit 3bd9aa1f9250cd15f5371f3622d73d954b68a747 +Subproject commit b9a2170fce1f3f33cb4934b34efecb806bbbb348