import gzip
import html
import os
from functools import lru_cache

import ftfy
import numpy as np
import regex as re
from PIL import Image


@lru_cache()
def bytes_to_unicode():
    """Map every byte value (0-255) to a printable unicode character.

    BPE merges operate on unicode strings, so each byte needs a visible
    stand-in: printable bytes map to themselves, the rest are shifted
    above U+0100.
    """
    bs = (list(range(ord("!"), ord("~") + 1))
          + list(range(ord("¡"), ord("¬") + 1))
          + list(range(ord("®"), ord("ÿ") + 1)))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))
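# For example (a sketch of the mapping above; 'Ġ' is U+0120):
#   enc = bytes_to_unicode()
#   enc[ord("a")]   -> "a"   (printable bytes keep their own character)
#   enc[ord(" ")]   -> "Ġ"   (the space byte is remapped above U+0100)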


def get_pairs(word):
    """Return the set of adjacent symbol pairs in `word`, a tuple of symbols."""
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs
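# For example:
#   get_pairs(("h", "e", "l", "l", "o</w>"))
#   -> {("h", "e"), ("e", "l"), ("l", "l"), ("l", "o</w>")}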


def basic_clean(text):
    # ftfy repairs mojibake; unescaping twice also resolves double-escaped
    # HTML entities.
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    # Collapse any run of whitespace (spaces, tabs, newlines) to one space.
    text = re.sub(r"\s+", " ", text)
    return text.strip()
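# For example:
#   whitespace_clean("  a\n b\tc ")  -> "a b c"
#   basic_clean("&amp;amp;")         -> "&"  (double-escaped entity resolved)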


class SimpleTokenizer(object):
    def __init__(self, bpe_path: str):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        # Load the merge list, skipping the header line and keeping exactly
        # the merges the 49,152-entry vocab was built from (the rest of the
        # vocab is 512 byte symbols plus 2 special tokens).
        with gzip.open(bpe_path, "rb") as f:
            merges = f.read().decode("utf-8").split("\n")
        merges = merges[1:49152 - 256 - 2 + 1]
        merges = [tuple(merge.split()) for merge in merges]
        # Vocab: 256 byte symbols, the same symbols marked word-final with
        # "</w>", one entry per merge, and the two special tokens.
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v + "</w>" for v in vocab]
        for merge in merges:
            vocab.append("".join(merge))
        vocab.extend(["<|startoftext|>", "<|endoftext|>"])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        # Lower rank means the merge was learned earlier and applies first.
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {"<|startoftext|>": "<|startoftext|>",
                      "<|endoftext|>": "<|endoftext|>"}
        # Pre-tokenizer: special tokens, common contractions, letter runs,
        # single digits, or runs of other non-whitespace characters.
        self.pat = re.compile(
            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            re.IGNORECASE,
        )
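    # For example, on the lowercased text "clip's tokenizer, v2!" the
    # pattern yields: ["clip", "'s", "tokenizer", ",", "v", "2", "!"].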

    def bpe(self, token):
        """Greedily apply the learned merges to one pre-tokenized word."""
        if token in self.cache:
            return self.cache[token]
        # Start from single characters, marking the last one as word-final.
        word = tuple(token[:-1]) + (token[-1] + "</w>",)
        pairs = get_pairs(word)
        if not pairs:
            return token + "</w>"
        while True:
            # Always merge the lowest-ranked (earliest-learned) pair first.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(
                pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                # Copy symbols up to the next occurrence of `first`.
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    new_word.extend(word[i:])
                    break
                # Merge the pair if `second` follows; otherwise keep the symbol.
                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            word = tuple(new_word)
            if len(word) == 1:
                break
            pairs = get_pairs(word)
        word = " ".join(word)
        self.cache[token] = word
        return word
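    # For example, with a hypothetical merge table that ranks ("l", "o</w>")
    # highest, bpe("hello") would go ("h", "e", "l", "l", "o</w>")
    # -> ("h", "e", "l", "lo</w>") -> ... until no pair has a rank; the
    # actual sequence depends entirely on the loaded merge table.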

    def encode(self, text):
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            # Represent the token's UTF-8 bytes as printable characters.
            token = "".join(self.byte_encoder[b]
                            for b in token.encode("utf-8"))
            # Run BPE and look up an id for every resulting subword.
            bpe_tokens.extend(self.encoder[bpe_token]
                              for bpe_token in self.bpe(token).split(" "))
        return bpe_tokens
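    # For example (the exact ids depend on the loaded merge file):
    #   tok = SimpleTokenizer(default_bpe_path())
    #   tok.encode("hello world")  -> a short list of subword ids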


@lru_cache()
def default_bpe_path():
    # The merge file is expected to sit next to this module.
    return os.path.join(os.path.dirname(os.path.abspath(__file__)),
                        "bpe_simple_vocab_16e6.txt.gz")


def tokenize(texts, context_length=77, truncate=False):
    """Tokenize a string or list of strings into a zero-padded id matrix.

    Returns an int64 array of shape (len(texts), context_length).
    """
    if isinstance(texts, str):
        texts = [texts]
    tokenizer = SimpleTokenizer(default_bpe_path())
    sot_token = tokenizer.encoder["<|startoftext|>"]
    eot_token = tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + tokenizer.encode(text) + [eot_token]
                  for text in texts]
    result = np.zeros((len(all_tokens), context_length), dtype=np.int64)
    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if truncate:
                # Cut to length but keep the end-of-text marker.
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
            else:
                raise RuntimeError(
                    f"Input {texts[i]} is too long for context length {context_length}")
        result[i, :len(tokens)] = np.array(tokens, dtype=np.int64)
    return result
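# For example:
#   ids = tokenize(["a photo of a cat", "a photo of a dog"])
#   ids.shape  -> (2, 77); each row is [sot, subword ids..., eot, 0, ...].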


def preprocess(image: Image.Image, image_resolution: int = 224) -> np.ndarray:
    """Resize, center-crop, and normalize a PIL image into an NCHW float32 array."""
    image = image.convert("RGB")
    # Resize so the shorter side equals the target resolution.
    width, height = image.size
    size = image_resolution
    if width < height:
        new_width = size
        new_height = int(round(size * height / width))
    else:
        new_height = size
        new_width = int(round(size * width / height))
    image = image.resize((new_width, new_height), Image.BICUBIC)
    # Center-crop to a square of side `size`.
    left = (new_width - size) // 2
    top = (new_height - size) // 2
    image = image.crop((left, top, left + size, top + size))
    # Scale to [0, 1] and normalize with the CLIP channel statistics.
    arr = np.array(image).astype(np.float32) / 255.0
    mean = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)
    std = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)
    arr = (arr - mean) / std
    # HWC -> CHW, then add a batch dimension.
    arr = arr.transpose(2, 0, 1)
    return arr[np.newaxis, ...].astype(np.float32)
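

if __name__ == "__main__":
    # Minimal smoke test, a sketch: it assumes the merge file is present and
    # that an image named "example.jpg" (a hypothetical path) sits alongside
    # this script.
    ids = tokenize(["a photo of a cat", "a photo of a dog"])
    print(ids.shape)      # (2, 77)
    print(ids[0, :6])     # starts with the <|startoftext|> id

    img = preprocess(Image.open("example.jpg"))
    print(img.shape)      # (1, 3, 224, 224)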