@@ -323,15 +323,27 @@ def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None:
         self.bpe_tokenizer = json.loads(open(str(fname_tokenizer), encoding="utf-8").read())
         added_tokens: dict[str, int]
         if fname_added_tokens is not None:
+            # FIXME: Verify that added tokens here _cannot_ overlap with the main vocab.
             added_tokens = json.load(open(fname_added_tokens, encoding="utf-8"))
         else:
-            added_tokens = {}
+            # Fall back to trying to find the added tokens in tokenizer.json
+            tokenizer_json_file = fname_tokenizer.parent / 'tokenizer.json'
+            if not tokenizer_json_file.is_file():
+                added_tokens = {}
+            else:
+                tokenizer_json = json.load(open(tokenizer_json_file, encoding="utf-8"))
+                added_tokens = dict(
+                    (item['content'], item['id'])
+                    for item in tokenizer_json.get('added_tokens', [])
+                    # Added tokens here can be duplicates of the main vocabulary.
+                    if item['content'] not in self.bpe_tokenizer)
 
         vocab_size: int = len(self.bpe_tokenizer)
         expected_ids = list(range(vocab_size, vocab_size + len(added_tokens)))
         actual_ids = sorted(added_tokens.values())
         if expected_ids != actual_ids:
-            raise Exception(f"Expected added token IDs to be sequential and start at {len(added_tokens)}; got {actual_ids}")
+            expected_end_id = vocab_size + len(actual_ids) - 1
+            raise Exception(f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range {vocab_size} - {expected_end_id}; got {actual_ids}")
 
         items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1])
         self.added_tokens_list = [text for (text, idx) in items]
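To make the new fallback concrete, here is a minimal sketch of the same lookup as a standalone function; load_added_tokens is a hypothetical helper written for illustration, not part of convert.py:

import json
from pathlib import Path

def load_added_tokens(fname_tokenizer: Path, bpe_vocab: dict) -> dict:
    # Hypothetical helper mirroring the fallback above: when no added_tokens.json
    # is supplied, look for a tokenizer.json sitting next to the main vocab file.
    tokenizer_json_file = fname_tokenizer.parent / 'tokenizer.json'
    if not tokenizer_json_file.is_file():
        return {}
    with open(tokenizer_json_file, encoding="utf-8") as f:
        tokenizer_json = json.load(f)
    # Entries in tokenizer.json's "added_tokens" list can duplicate the main
    # vocabulary, so keep only the ones that are genuinely new.
    return {item['content']: item['id']
            for item in tokenizer_json.get('added_tokens', [])
            if item['content'] not in bpe_vocab}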
@@ -345,10 +357,22 @@ def bpe_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
         from transformers.models.gpt2 import tokenization_gpt2  # type: ignore[import]
         byte_encoder = tokenization_gpt2.bytes_to_unicode()
         byte_decoder = {v: k for k, v in byte_encoder.items()}
+        score = 0.0
         for i, item in enumerate(tokenizer):
             text: bytes = item.encode("utf-8")
-            score: float = -i
-            yield text, score, gguf.TokenType.USER_DEFINED
+            # FIXME: These shouldn't be hardcoded, but it's probably better than the current behavior?
+            if i <= 258 and text.startswith(b'<') and text.endswith(b'>'):
+                if i == 0 and text == b'<unk>':
+                    toktype = gguf.TokenType.UNKNOWN
+                elif i == 1 or i == 2:
+                    toktype = gguf.TokenType.CONTROL
+                elif i >= 3 and text.startswith(b'<0x'):
+                    toktype = gguf.TokenType.BYTE
+                else:
+                    toktype = gguf.TokenType.NORMAL
+            else:
+                toktype = gguf.TokenType.NORMAL
+            yield text, score, toktype
 
     def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
         for text in self.added_tokens_list:
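For reference, a small sketch of how the hardcoded checks above classify tokens; the sample vocabulary below is made up in the Falcon style (not taken from a real model) and the token types are printed as plain strings rather than gguf.TokenType members:

# Made-up sample of the first vocabulary entries, Falcon-style.
sample_vocab = ['<unk>', '<s>', '</s>', '<0x00>', '<0x0A>', '<pad>', 'the']

for i, item in enumerate(sample_vocab):
    text = item.encode("utf-8")
    if i <= 258 and text.startswith(b'<') and text.endswith(b'>'):
        if i == 0 and text == b'<unk>':
            toktype = 'UNKNOWN'
        elif i == 1 or i == 2:
            toktype = 'CONTROL'
        elif i >= 3 and text.startswith(b'<0x'):
            toktype = 'BYTE'
        else:
            toktype = 'NORMAL'   # e.g. '<pad>' at index 5 falls through to here
    else:
        toktype = 'NORMAL'
    print(i, item, toktype)
# <unk> -> UNKNOWN, <s>/</s> -> CONTROL, <0x..> -> BYTE, everything else -> NORMAL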