forked from explosion/spaCy
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpacker.pyx
More file actions
196 lines (171 loc) · 5.9 KB
/
packer.pyx
File metadata and controls
196 lines (171 loc) · 5.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
# cython: profile=True
from __future__ import unicode_literals
from libc.stdint cimport uint32_t, int32_t
from libc.stdint cimport uint64_t
from libc.math cimport exp as c_exp
from libcpp.queue cimport priority_queue
from libcpp.pair cimport pair
from cymem.cymem cimport Address, Pool
from preshed.maps cimport PreshMap
from preshed.counter cimport PreshCounter
import json
from ..attrs cimport ORTH, ID, SPACY, TAG, HEAD, DEP, ENT_IOB, ENT_TYPE
from ..tokens.doc cimport Doc
from ..vocab cimport Vocab
from ..structs cimport LexemeC
from ..typedefs cimport attr_t
from .bits cimport BitArray
from .huffman cimport HuffmanCodec
from os import path
import numpy
from .. import util
cimport cython
# Format
# - Total number of bytes in message (32-bit int) --- handled outside this file
# - Number of words (32-bit int); a negative value signals character encoding
#   and its magnitude is the number of UTF-8 bytes instead
# - Words, terminating in an EOL symbol, Huffman-coded at ~12 bits per word
# - Spaces: 1 bit per word
# - Attributes:
#     POS tag
#     Head offset
#     Dep label
#     Entity IOB
#     Entity tag
cdef class _BinaryCodec:
    """Codec that copies values between a typed buffer and a BitArray.

    Each element of the message is handed to `BitArray.append` on encode and
    filled from one item of the BitArray's iteration on decode.
    """

    def encode(self, attr_t[:] msg, BitArray bits):
        """Append every element of `msg` to `bits`, in order.

        msg: Buffer of values to write.
        bits: Destination BitArray; extended in place.
        """
        cdef int i
        for i in range(len(msg)):
            bits.append(msg[i])

    def decode(self, BitArray bits, attr_t[:] msg):
        """Fill `msg` with successive items obtained by iterating `bits`.

        Stops as soon as `msg` is full, or earlier if `bits` is exhausted.

        bits: Source BitArray to iterate over.
        msg: Output buffer; overwritten in place.
        """
        cdef int i = 0
        cdef int n = len(msg)
        # An empty buffer has nothing to fill. Returning before iterating
        # avoids writing the first bit to msg[0], which would be out of bounds.
        if n == 0:
            return
        for bit in bits:
            msg[i] = bit
            i += 1
            if i == n:
                break
def _gen_orths(Vocab vocab):
    """Yield an (orth ID, weight) pair for each entry of `vocab._by_orth`.

    The weight is exp(lex.prob) of the LexemeC struct the entry's address
    points at.
    """
    cdef attr_t orth_id
    cdef size_t lex_addr
    for orth_id, lex_addr in vocab._by_orth.items():
        # The map stores raw addresses; cast back to a lexeme pointer.
        lexeme = <LexemeC*>lex_addr
        yield orth_id, c_exp(lexeme.prob)
def _gen_chars(Vocab vocab):
    """Compute a weight for every byte value from the vocabulary's strings.

    All 256 byte values start at 1e-20, so each one receives a non-zero
    weight. For each entry of `vocab._by_orth`, exp(lex.prob) is added once
    per byte of the entry's UTF-8 encoded string, and once more to the space
    character.

    Returns the (byte value, weight) items of the resulting table.
    """
    cdef attr_t orth_id
    cdef size_t lex_addr
    cdef unicode text
    cdef bytes byte_char
    cdef bytes encoded
    cdef double weight
    char_weights = {value: 1e-20 for value in range(256)}
    for orth_id, lex_addr in vocab._by_orth.items():
        lexeme = <LexemeC*>lex_addr
        # The weight is the same for every byte of this lexeme; compute once.
        weight = c_exp(lexeme.prob)
        text = vocab.strings[lexeme.orth]
        encoded = text.encode('utf8')
        for byte_char in encoded:
            # Every byte value is pre-seeded above, so no setdefault is needed.
            char_weights[ord(byte_char)] += weight
        char_weights[ord(' ')] += weight
    return char_weights.items()
cdef class Packer:
    """Pack Doc objects into byte strings and unpack them again.

    A document is written as a 32-bit length header followed by its words and
    then one Huffman-coded section per attribute in `self.attrs`. Words are
    coded by orth ID when possible; otherwise the UTF-8 text is coded byte by
    byte and the header is written as a negative number.
    """

    def __init__(self, Vocab vocab, attr_freqs, char_freqs=None):
        """Build the word, character and per-attribute Huffman codecs.

        vocab: Vocabulary supplying word weights; also used when decoding.
        attr_freqs: Iterable of (attribute ID, frequencies) pairs. Pairs are
            processed in sorted order; ORTH, ID and SPACY are skipped.
        char_freqs: Frequencies for the character codec. When None they are
            computed from `vocab` with `_gen_chars`.
        """
        if char_freqs is None:
            char_freqs = _gen_chars(vocab)
        self.vocab = vocab
        self.orth_codec = HuffmanCodec(_gen_orths(vocab))
        self.char_codec = HuffmanCodec(char_freqs)
        attr_codecs = []
        attr_ids = []
        for attr_id, freqs in sorted(attr_freqs):
            if attr_id not in (ORTH, ID, SPACY):
                attr_codecs.append(HuffmanCodec(freqs))
                attr_ids.append(attr_id)
        self._codecs = tuple(attr_codecs)
        self.attrs = tuple(attr_ids)

    def pack(self, Doc doc):
        """Encode `doc` and return the packed message from `bits.as_bytes()`.

        Orth encoding is tried first; character encoding is the fallback when
        `_orth_encode` returns None.
        """
        cdef int col
        bits = self._orth_encode(doc)
        if bits is None:
            bits = self._char_encode(doc)
        if self.attrs:
            attr_array = doc.to_array(self.attrs)
            # One column per attribute, each with its own codec.
            for col, codec in enumerate(self._codecs):
                codec.encode(attr_array[:, col], bits)
        return bits.as_bytes()

    def unpack(self, data):
        """Decode `data` into a new Doc created from `self.vocab`."""
        doc = Doc(self.vocab)
        self.unpack_into(data, doc)
        return doc

    def unpack_into(self, byte_string, Doc doc):
        """Decode `byte_string` into the given `doc` and return it.

        The first 32 bits are read as a signed header: a non-negative value is
        passed to `_orth_decode` as the word count, a negative one is negated
        and passed to `_char_decode` as the byte count.
        """
        cdef int32_t header
        bits = BitArray(byte_string)
        bits.seek(0)
        header = bits.read32()
        if header < 0:
            self._char_decode(bits, -header, doc)
        else:
            self._orth_decode(bits, header, doc)
        attr_array = numpy.zeros(shape=(len(doc), len(self._codecs)),
                                 dtype=numpy.int32)
        for col, codec in enumerate(self._codecs):
            codec.decode(bits, attr_array[:, col])
        doc.from_array(self.attrs, attr_array)
        return doc

    def _orth_encode(self, Doc doc):
        """Encode the words of `doc` by orth ID.

        Returns the BitArray, or None when any token is out-of-vocabulary or
        the orth codec reports that it wrote 0 bits.
        """
        cdef BitArray bits
        cdef int32_t n_tokens
        for token in doc:
            if token.is_oov:
                return None
        bits = BitArray()
        n_tokens = len(doc)
        bits.extend(n_tokens, 32)
        orth_ids = doc.to_array([ORTH])
        if self.orth_codec.encode_int32(orth_ids[:, 0], bits) == 0:
            return None
        # One bit per token: is it followed by whitespace?
        for token in doc:
            bits.append(bool(token.whitespace_))
        return bits

    def _char_encode(self, Doc doc):
        """Encode the UTF-8 text of `doc`, followed by token-boundary bits.

        For each token the boundary section holds `lex.length - 1` False bits,
        then a True bit, then one more False bit if the token's `spacy` flag
        is set.
        """
        cdef bytes encoded = doc.string.encode('utf8')
        cdef BitArray bits = BitArray()
        cdef int32_t n_bytes = len(encoded)
        cdef int tok, k
        # Signal chars with negative length
        bits.extend(-n_bytes, 32)
        self.char_codec.encode(bytearray(encoded), bits)
        for tok in range(doc.length):
            for k in range(doc.c[tok].lex.length - 1):
                bits.append(False)
            bits.append(True)
            if doc.c[tok].spacy:
                bits.append(False)
        return bits

    def _orth_decode(self, BitArray bits, int32_t n, Doc doc):
        """Decode `n` orth IDs and their space bits from `bits` into `doc`.

        The IDs are decoded first; the space flags are then taken one at a
        time from an iterator over `bits`.
        """
        cdef attr_t[:] orth_ids = numpy.ndarray(shape=(n,), dtype=numpy.int32)
        cdef int idx
        cdef bint has_space
        self.orth_codec.decode_int32(bits, orth_ids)
        bit_iter = iter(bits)
        for idx in range(n):
            orth_id = orth_ids[idx]
            has_space = next(bit_iter)
            lexeme = self.vocab.get_by_orth(doc.mem, orth_id)
            doc.push_back(lexeme, has_space)
        return doc

    def _char_decode(self, BitArray bits, int32_t n_bytes, Doc doc):
        """Decode `n_bytes` of UTF-8 text and rebuild the tokens of `doc`.

        After the text is decoded, `bits` is iterated one item per character
        position; a true bit marks the last character of a token.
        """
        cdef bytearray raw = bytearray(n_bytes)
        self.char_codec.decode(bits, raw)
        cdef unicode text = raw.decode('utf8')
        cdef int n_chars = len(text)
        cdef int token_start = 0
        cdef int pos = 0
        cdef bint followed_by_space
        cdef bint ends_token
        for ends_token in bits:
            if ends_token:
                token_text = text[token_start:pos + 1]
                lexeme = self.vocab.get(doc.mem, token_text)
                followed_by_space = (pos + 1) < n_chars and text[pos + 1] == u' '
                doc.push_back(lexeme, followed_by_space)
                # The next token begins after this one and its trailing space.
                token_start = pos + 1 + followed_by_space
            pos += 1
            if pos >= n_chars:
                break
        return doc