import math
import numpy as np
import torch

class MultiheadAttention(torch.nn.Module):
    """Multi-head attention that can blend a target prompt with weighted reference prompts."""
    def __init__(self, d_model, n_head, n_token = 77, dropout = 0.1):
        super().__init__()
        self.d_model = d_model
        self.n_head = n_head
        self.d_head = d_model // n_head
        self.n_token = n_token

        self.query = torch.nn.Linear(d_model, d_model)
        self.key = torch.nn.Linear(d_model, d_model)
        self.value = torch.nn.Linear(d_model, d_model)
        self.proj = torch.nn.Linear(d_model, d_model)

        # scaling factor for scaled dot-product attention
        self.div = torch.sqrt(torch.tensor(self.d_head, dtype = self.query.weight.dtype))

        self.softmax = torch.nn.Softmax(dim = -1)
        self.dropout = torch.nn.Dropout(dropout)

        self._reset_parameters()
|
|
    def _reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.query.weight)
        torch.nn.init.xavier_uniform_(self.key.weight)
        torch.nn.init.xavier_uniform_(self.value.weight)
        torch.nn.init.xavier_uniform_(self.proj.weight)

        torch.nn.init.constant_(self.query.bias, 0.)
        torch.nn.init.constant_(self.key.bias, 0.)
        torch.nn.init.constant_(self.value.bias, 0.)
        torch.nn.init.constant_(self.proj.bias, 0.)

    def forward(self, q, k, v, mask = None, weight = None, alpha = None):
        b, s = q.shape[:2]
        b2, s2 = k.shape[:2]

        q = self.query(q)
        k = self.key(k)
        v = self.value(v)

        # (batch, seq, d_model) -> (batch, n_head, seq, d_head)
        q = q.view(-1, s, self.n_head, self.d_head).transpose(1, 2)
        k = k.view(-1, s2, self.n_head, self.d_head).transpose(1, 2)
        v = v.view(-1, s2, self.n_head, self.d_head).transpose(1, 2)

        # scaled dot-product attention scores: (batch, n_head, s, s2)
        score = torch.matmul(q, k.transpose(-2, -1)) / self.div

        if mask is not None:
            mask = mask.unsqueeze(1)
            if mask.dim() != score.dim():
                mask = mask.unsqueeze(2)
            score = score * mask

        if weight is not None:
            # broadcast (batch, n_key_group) weights to (batch, 1, 1, n_key_group)
            weight = weight.unsqueeze(1)
            if weight.dim() != score.dim():
                weight = weight.unsqueeze(2)
        if self.n_token == s2:
            # keys hold only the target prompt: plain softmax, optionally re-weighted per key token
            w = self.softmax(score)
            if weight is not None:
                w = w * weight
                w = w / (w.sum(dim = -1, keepdim = True) + 1e-12)
        else:
            # keys hold the target prompt (first n_token) followed by reference prompts (n_token each):
            # normalize the target block and each reference block separately, then blend them with alpha
            target, ref = torch.split(score, [self.n_token, s2 - self.n_token], dim = -1)
            target = self.softmax(target)
            if alpha is None:
                alpha = 0.5
            if weight is None:
                # fall back to uniform weights: one per target key token, one per reference prompt
                weight = torch.ones((1, 1, 1, self.n_token + (s2 - self.n_token) // self.n_token), dtype = score.dtype, device = score.device)
            ws = weight.shape[-1]
            target_weight, ref_weight = torch.split(weight, [self.n_token, ws - self.n_token], dim = -1)
            ref = ref.view(b2, self.n_head, s, ws - self.n_token, self.n_token)
            ref = self.softmax(ref)
            ref = ref * ref_weight.unsqueeze(-1)
            ref = ref.view(b2, self.n_head, s, s2 - self.n_token)
            ref = alpha * (ref / (ref.sum(dim = -1, keepdim = True) + 1e-12))
            target = target * (1 - alpha) * target_weight
            w = torch.cat([target, ref], dim = -1)
            w = w / (w.sum(dim = -1, keepdim = True) + 1e-12)
        w = self.dropout(w)

        out = torch.matmul(w, v)
        out = out.transpose(1, 2).contiguous().view(b, s, self.d_model)
        out = self.proj(out)
        return out
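
# Hedged usage sketch (not part of the original module): exercises MultiheadAttention with a target
# prompt block plus two reference prompt blocks in the keys/values. The helper name, dimensions, and
# weights below are illustrative assumptions.
def _example_multihead_attention():
    d_model, n_head, n_token, n_ref = 64, 4, 77, 2
    attn = MultiheadAttention(d_model, n_head, n_token = n_token, dropout = 0.0)
    q = torch.randn(1, n_token, d_model)                  # target query tokens
    kv = torch.randn(1, n_token * (1 + n_ref), d_model)   # target + reference key/value tokens
    weight = torch.ones(1, n_token + n_ref)               # per-target-token and per-reference weights
    out = attn(q, kv, kv, weight = weight, alpha = 0.3)
    return out.shape                                       # torch.Size([1, 77, 64])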

class QuickGELU(torch.nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(1.702 * x)
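
# Hedged sketch (illustrative only): QuickGELU is the sigmoid-based GELU approximation used by CLIP;
# this helper, an assumption for demonstration, compares it against the exact GELU on a few points.
def _example_quick_gelu():
    x = torch.linspace(-3., 3., steps = 7)
    approx = QuickGELU()(x)
    exact = torch.nn.functional.gelu(x)
    return (approx - exact).abs().max()  # small (< 0.05) approximation error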
|
|
class TransformerBlock(torch.nn.Module):
    """Post-norm transformer block: (cross-)attention followed by a feed-forward network."""
    def __init__(self, emb_dim, n_head, ff_dim, n_token = 77, activation = "quick_gelu", dropout = 0.1):
        super().__init__()
        self.attn = MultiheadAttention(emb_dim, n_head, n_token = n_token, dropout = dropout)
        # accept a string name ("gelu", "relu", "quick_gelu") or an already-built activation module
        if activation is None or (isinstance(activation, str) and activation.lower() == "gelu"):
            self.act = torch.nn.GELU()
        elif isinstance(activation, str) and activation.lower() == "relu":
            self.act = torch.nn.ReLU()
        elif isinstance(activation, str) and activation.lower() == "quick_gelu":
            self.act = QuickGELU()
        else:
            self.act = activation
        self.ff = torch.nn.Sequential(
            torch.nn.Linear(emb_dim, ff_dim),
            self.act,
            torch.nn.Linear(ff_dim, emb_dim),
        )
        self.norm1 = torch.nn.LayerNorm(emb_dim)
        self.norm2 = torch.nn.LayerNorm(emb_dim)
        self.dropout1 = torch.nn.Dropout(dropout)
        self.dropout2 = torch.nn.Dropout(dropout)

        self._reset_parameters()
|
|
    def _reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.ff[0].weight)
        torch.nn.init.xavier_uniform_(self.ff[2].weight)

        torch.nn.init.constant_(self.ff[0].bias, 0.)
        torch.nn.init.constant_(self.ff[2].bias, 0.)

    def forward(self, x, context = None, mask = None, weight = None, alpha = None):
        context = context if context is not None else x
        # (cross-)attention + residual + norm
        out = self.attn(x, context, context, mask = mask, weight = weight, alpha = alpha)
        out = x + self.dropout1(out)
        out = self.norm1(out)

        # feed-forward + residual + norm
        ff_out = self.ff(out)
        out = out + self.dropout2(ff_out)
        out = self.norm2(out)
        return out
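
# Hedged usage sketch (illustrative assumptions only): runs one TransformerBlock in self-attention
# mode and in cross-attention mode against a longer target-plus-reference context. The helper name,
# dimensions, and weights are not part of the original code.
def _example_transformer_block():
    block = TransformerBlock(emb_dim = 64, n_head = 4, ff_dim = 128, n_token = 77, dropout = 0.0)
    x = torch.randn(2, 77, 64)
    self_out = block(x)                                    # context defaults to x (self-attention)
    context = torch.randn(2, 77 * 3, 64)                   # target + 2 reference prompts as context
    weight = torch.ones(2, 77 + 2)
    cross_out = block(x, context, weight = weight, alpha = 0.3)
    return self_out.shape, cross_out.shape                 # both torch.Size([2, 77, 64])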
|
|
class PersonalizedAdapter(torch.nn.Module):
    """Stack of TransformerBlocks that refines a learnable query against reference-prompt context."""
    def __init__(self, emb_dim, n_head, ff_dim, n_layer = 4, n_token = 77, proj = False, extra_proj = False, pos = True, cls_pos = False, cls_token = True, encode_ratio = None, activation = "quick_gelu", dropout = 0.1):
        super().__init__()
        self.n_layer = n_layer
        self.n_token = n_token
        self.cls_pos = cls_pos
        self.cls_token = cls_token
        self.encode_ratio = encode_ratio

        # optional bottleneck: run the adapter at emb_dim / encode_ratio and project back afterwards
        self.pre_proj = self.post_proj = None
        if encode_ratio and encode_ratio != 1:
            self.pre_proj = torch.nn.Linear(emb_dim, int(emb_dim // encode_ratio))
            self.post_proj = torch.nn.Linear(int(emb_dim // encode_ratio), emb_dim)
            emb_dim = int(emb_dim // encode_ratio)
            n_head = int(n_head // encode_ratio)

        # accept a string name ("gelu", "relu", "quick_gelu") or an already-built activation module
        if activation is None or (isinstance(activation, str) and activation.lower() == "gelu"):
            self.act = torch.nn.GELU()
        elif isinstance(activation, str) and activation.lower() == "relu":
            self.act = torch.nn.ReLU()
        elif isinstance(activation, str) and activation.lower() == "quick_gelu":
            self.act = QuickGELU()
        else:
            self.act = activation

        # learnable base query (token queries plus an optional class-token query) and positional embedding
        self.base_query = torch.nn.Parameter(torch.empty(1, n_token + int(cls_token), emb_dim))
        self.pos = torch.nn.Parameter(torch.empty(1, n_token + int(cls_pos and cls_token), emb_dim)) if pos else None
        self.init_query = None

        self.proj = None
        if proj:
            self.proj = torch.nn.Sequential(
                torch.nn.Linear(emb_dim, ff_dim),
                self.act,
                torch.nn.Linear(ff_dim, emb_dim),
            )

        self.extra_proj = None
        self.tf = torch.nn.ModuleList([TransformerBlock(emb_dim, n_head, ff_dim, n_token = n_token, activation = activation, dropout = dropout) for _ in range(n_layer)])
        if extra_proj:
            self.extra_proj = torch.nn.ModuleList([torch.nn.Linear(emb_dim, emb_dim) for _ in range(n_layer)])

        self._reset_parameters()
|
|
    def _reset_parameters(self):
        torch.nn.init.normal_(self.base_query, std = 0.02)
        if self.pos is not None:
            torch.nn.init.normal_(self.pos, std = 0.01)

        for proj in [self.pre_proj, self.post_proj]:
            if proj is not None:
                torch.nn.init.xavier_uniform_(proj.weight)
                torch.nn.init.constant_(proj.bias, 0.)
        for proj in [self.proj]:
            if proj is not None:
                torch.nn.init.xavier_uniform_(proj[0].weight)
                torch.nn.init.xavier_uniform_(proj[2].weight)

                torch.nn.init.constant_(proj[0].bias, 0.)
                torch.nn.init.constant_(proj[2].bias, 0.)
        if self.extra_proj is not None:
            for l in self.extra_proj:
                torch.nn.init.xavier_uniform_(l.weight)
                torch.nn.init.constant_(l.bias, 0.)

    def set_base_query(self, x):
        if not torch.is_tensor(x):
            x = torch.tensor(x, dtype = self.base_query.dtype).to(self.base_query.device)
        if x.dim() == 2:
            x = x.unsqueeze(0)
        self.init_query = x

    def normal_forward(self, x, context, mask = None, weight = None, alpha = None):
        out = x
        for i in range(self.n_layer):
            if self.extra_proj is not None:
                _context = self.extra_proj[i](self.act(context))
            else:
                _context = context
            out = self.tf[i](out, _context, mask = mask, weight = weight, alpha = alpha)
        if self.cls_token:
            # the last query position acts as a class token: return (token outputs, pooled output)
            return out[:, :-1], out[:, -1]
        else:
            return out, None

    def forward(self, context, mask = None, weight = None, alpha = None, base_query = None):
        dtype = self.base_query.dtype
        # start from an explicit base_query, a stored init_query, or the learnable base query
        if base_query is not None:
            x = base_query
        else:
            x = self.base_query if self.init_query is None else self.init_query
        x = x.type(dtype)
        if context is not None:
            context = context.type(dtype)
        if weight is not None:
            weight = weight.type(dtype)
        if self.encode_ratio is not None and x.shape[-1] != self.base_query.shape[-1]:
            x = self.pre_proj(x)
        # split off the class-token query (if any) before adding positional embeddings
        if self.n_token < x.shape[1]:
            x, cls = x[:, :self.n_token], x[:, self.n_token:]
        else:
            cls = self.base_query[:, self.n_token:] if self.cls_token else None
        if self.pos is not None:
            if self.cls_pos and self.cls_token:
                x = x + self.pos[:, :self.n_token]
                if cls is not None:
                    cls = cls + self.pos[:, self.n_token:]
            else:
                x = x + self.pos
        if self.cls_token:
            x = torch.cat([x, cls], dim = 1)
        x = x.repeat_interleave(context.shape[0], dim = 0)
        if self.encode_ratio is not None:
            if context is not None:
                context = self.pre_proj(context)
        if self.proj is not None:
            context = self.proj(context)
        out = self.normal_forward(x, context, mask = mask, weight = weight, alpha = alpha)
        if self.encode_ratio is not None:
            out = (self.post_proj(out[0]), self.post_proj(out[1]) if out[1] is not None else out[1])
        return out
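
# Hedged usage sketch (illustrative assumptions only): builds a small PersonalizedAdapter and runs it
# on random "reference prompt" hidden states. The dimensions, helper name, and uniform weights are
# assumptions for demonstration, not values from the original code.
def _example_personalized_adapter():
    n_token, emb_dim = 77, 64
    adapter = PersonalizedAdapter(emb_dim, n_head = 4, ff_dim = 128, n_layer = 2, n_token = n_token, dropout = 0.0)
    # context = target prompt tokens followed by 2 reference prompts, already encoded to emb_dim
    context = torch.randn(1, n_token * 3, emb_dim)
    weight = torch.ones(1, n_token + 2)                    # per-target-token + per-reference weights
    tokens, pooled = adapter(context, weight = weight, alpha = 0.3)
    return tokens.shape, pooled.shape                      # torch.Size([1, 77, 64]), torch.Size([1, 64])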
|
|
class DrUM:
    """Couples a frozen text encoder (CLIP-style or T5-style) with a trainable PersonalizedAdapter."""
    def __init__(self, model, processor, n_layer = 8, proj = False, extra_proj = False, mlp_ratio = 4, pos = True, cls_pos = False, cls_token = True, encode_ratio = None, max_token_size = 256, activation = "quick_gelu", dropout = 0.1):
        config = model.config.text_config if hasattr(model.config, "text_config") else model.config
        if hasattr(config, "model_type") and config.model_type == "t5":
            self.d_model = config.d_model
            self.n_head = config.num_heads
            self.n_token = min(processor.model_max_length, max_token_size)
            self.clip = False
            self.cls_token = False
        else:
            self.d_model = config.hidden_size
            self.n_head = config.num_attention_heads
            self.n_token = config.max_position_embeddings
            self.clip = True
            self.cls_token = cls_token
        self.n_layer = n_layer
        self.proj = proj
        self.extra_proj = extra_proj
        self.mlp_ratio = mlp_ratio
        self.pos = pos
        self.cls_pos = cls_pos
        self.encode_ratio = encode_ratio
        self.activation = activation
        self.dropout = dropout

        self.model = model
        self.processor = processor
        self.adapter = PersonalizedAdapter(self.d_model, self.n_head, self.d_model // mlp_ratio, n_layer, self.n_token, proj = proj, extra_proj = extra_proj, pos = pos, cls_pos = cls_pos, cls_token = self.cls_token, encode_ratio = encode_ratio, activation = activation, dropout = dropout).to(model.device)

        self.train()
        self.to(model.device)

    def preprocess(self, text = None, image = None, return_tensors = "pt", padding = "max_length", truncation = True, **kwargs):
        feed = {"text":([text] if np.ndim(text) == 0 else list(text)) if text is not None else None,
                "return_tensors":return_tensors,
                "max_length":self.n_token,
                "padding":padding,
                "truncation":truncation,
                **kwargs}
        if not self.clip:
            feed["add_special_tokens"] = True
        if image is not None:
            feed["images"] = image
        return self.processor(**feed)

    def pool_text_hidden_state(self, hidden_state, x, padding = "max_length", truncation = True, **kwargs):
        if not self.clip:
            raise TypeError("T5 encoder does not support this function (pool_text_hidden_state).")
        if not hasattr(x, "items"):
            x = self.preprocess(text = x, padding = padding, truncation = truncation, **kwargs)
        if self.model.text_model.eos_token_id == 2:
            # legacy CLIP behaviour: the EOS token has the highest token id, so argmax locates it
            out = hidden_state[torch.arange(hidden_state.shape[0], device = hidden_state.device),
                               x["input_ids"].to(dtype = torch.int, device = hidden_state.device).argmax(dim = -1),]
        else:
            # otherwise locate the first occurrence of the EOS token id explicitly
            out = hidden_state[torch.arange(hidden_state.shape[0], device = hidden_state.device),
                               (x["input_ids"].to(dtype = torch.int, device = hidden_state.device) == self.model.text_model.eos_token_id).int().argmax(dim = -1),]
        return out

    def normalize_text_hidden_state(self, hidden_state):
        out = self.model.text_model.final_layer_norm(hidden_state.type(self.model.dtype)) if self.clip and hasattr(self.model.text_model, "final_layer_norm") else hidden_state
        return out

    def projection_text_hidden_state(self, hidden_state):
        out = self.model.text_projection(hidden_state.type(self.model.dtype)) if self.clip and hasattr(self.model, "text_projection") else hidden_state
        return out

    def encode_prompt(self, x, pooling = True, skip = -1, skip_pool = None, padding = "max_length", truncation = True, use_attn_mask = False, normalize = True, normalize_pool = True, **kwargs):
        if not hasattr(x, "items"):
            x = self.preprocess(text = x, padding = padding, truncation = truncation, **kwargs)
        input_ids = x["input_ids"].to(self.device)
        attention_mask = x["attention_mask"].to(self.device) if use_attn_mask else None
        with torch.no_grad():
            if self.clip:
                # `skip` indexes the encoder's hidden states (-1 = last layer, -2 = penultimate, ...)
                hidden_state = self.model.text_model(output_hidden_states = True, input_ids = input_ids, attention_mask = attention_mask)["hidden_states"]
                pool, hidden_state = hidden_state[skip_pool if skip_pool is not None else skip], hidden_state[skip]
                hidden_state = self.normalize_text_hidden_state(hidden_state) if normalize else hidden_state
            else:
                hidden_state = self.model(input_ids = input_ids, attention_mask = attention_mask)[0]
                pool = None
        if pooling:
            if self.clip:
                with torch.no_grad():
                    pool = self.pool_text_hidden_state(self.normalize_text_hidden_state(pool) if normalize_pool else pool, x, **kwargs)
            return (hidden_state, pool)
        return hidden_state

    def get_text_feature(self, x, ref_x = None, weight = None, alpha = 0.3, skip = -1, batch_size = 64, padding = "max_length", truncation = True, use_attn_mask = False, **kwargs):
        if not self.clip:
            raise TypeError("T5 encoder does not support this function (get_text_feature).")
        with torch.no_grad():
            pool_hidden_state = self(x, ref_x, weight = weight, alpha = alpha, pooling = True, skip_pool = skip, batch_size = batch_size, padding = padding, truncation = truncation, use_attn_mask = use_attn_mask, normalize_pool = True, **kwargs)[1]
            result = self.projection_text_hidden_state(pool_hidden_state)
        return result

    def get_image_feature(self, x, return_tensors = "pt", **kwargs):
        if not self.clip:
            raise TypeError("T5 encoder does not support this function (get_image_feature).")
        if hasattr(x, "items"):
            x = x["pixel_values"]
        elif not torch.is_tensor(x):
            x = self.preprocess(image = x, return_tensors = return_tensors, **kwargs)["pixel_values"]
        with torch.no_grad():
            result = self.model.get_image_features(pixel_values = x.to(self.device))
        return result

    def encode_context(self, ref_x, pooling = False, skip = -1, skip_pool = None, batch_size = 64, padding = "max_length", truncation = True, use_attn_mask = False, normalize = False, normalize_pool = False, **kwargs):
        if not hasattr(ref_x, "items"):
            # normalize raw reference prompts to a (batch, ref_size) nested list, then flatten for tokenization
            if np.ndim(ref_x) == 0:
                ref_x = [[ref_x]]
            elif np.ndim(ref_x) == 1:
                ref_x = [ref_x]
            b, ref_size = len(ref_x), len(ref_x[0])
            ref_x = np.reshape(ref_x, [b * ref_size])
            ref_x = self.preprocess(text = list(ref_x), padding = padding, truncation = truncation, **kwargs)
            ref_x = {k:v for k, v in ref_x.items() if k in (["input_ids", "attention_mask"] if use_attn_mask else ["input_ids"])}
        else:
            b, ref_size = ref_x["input_ids"].shape[:2]
            ref_x = {k:v.view(b * ref_size, -1) for k, v in ref_x.items() if k in (["input_ids", "attention_mask"] if use_attn_mask else ["input_ids"])}
        hidden_state, pool_hidden_state = [], []
        # encode the flattened references in mini-batches to bound memory usage
        batch_indices = [(i * batch_size, min((b * ref_size), (i + 1) * batch_size)) for i in range(int(np.ceil((b * ref_size) / batch_size)))]
        for start, end in batch_indices:
            h, p = self.encode_prompt({k:v[start:end] for k, v in ref_x.items()}, pooling = True, skip = skip, skip_pool = skip_pool, padding = padding, truncation = truncation, use_attn_mask = use_attn_mask, normalize = normalize, normalize_pool = normalize_pool, **kwargs)
            hidden_state.append(h)
            if p is not None:
                pool_hidden_state.append(p)
        hidden_state = torch.cat(hidden_state, dim = 0) if 1 < len(hidden_state) else hidden_state[0]
        pool_hidden_state = torch.cat(pool_hidden_state, dim = 0) if 1 < len(pool_hidden_state) else (pool_hidden_state[0] if len(pool_hidden_state) == 1 else None)
        with torch.no_grad():
            # fold the reference dimension into the sequence axis: (b, ref_size * n_token, d_model)
            hidden_state = hidden_state.view(b, ref_size * hidden_state.shape[1], -1)
            if pooling:
                if self.clip:
                    pool_hidden_state = pool_hidden_state.view(b, ref_size, -1)
                hidden_state = (hidden_state, pool_hidden_state)
        return hidden_state

    def __call__(self, x, ref_x = None, weight = None, alpha = 0.3, pooling = True, skip = -1, skip_pool = None, batch_size = 64, padding = "max_length", truncation = True, use_attn_mask = False, normalize = True, normalize_pool = True, training = False, **kwargs):
        if ref_x is not None or training:
            if training:
                # during adapter training the target prompt itself serves as the context
                context = weight = None
            else:
                _context, _context_pool = self.encode_context(ref_x, pooling = True, skip = skip, skip_pool = None, batch_size = batch_size, padding = padding, truncation = truncation, use_attn_mask = use_attn_mask, normalize = False, normalize_pool = False, **kwargs)
                if weight is not None:
                    if not torch.is_tensor(weight):
                        weight = torch.tensor(weight)
                    if weight.dim() == 0:
                        weight = weight.unsqueeze(0).unsqueeze(0)
                    elif weight.dim() == 1:
                        weight = weight.unsqueeze(0)
                    weight = weight.to(self.device)
                else:
                    # default to a uniform weight per reference prompt
                    weight = torch.ones((1, _context.shape[1] // self.n_token), dtype = torch.float32, device = _context.device)
                context = _context
                del _context, _context_pool
            result = self.encode_personalized_prompt(x, context, weight = weight, alpha = alpha, pooling = pooling, skip = skip, padding = padding, truncation = truncation, use_attn_mask = use_attn_mask, normalize = normalize, normalize_pool = normalize_pool, **kwargs)
            return result
        else:
            return self.encode_prompt(x, pooling = pooling, skip = skip, skip_pool = skip_pool, padding = padding, truncation = truncation, use_attn_mask = use_attn_mask, normalize = normalize, normalize_pool = normalize_pool, **kwargs)

    def encode_personalized_prompt(self, x, context = None, weight = None, alpha = 0.3, pooling = True, skip = -1, padding = "max_length", truncation = True, use_attn_mask = False, normalize = True, normalize_pool = True, **kwargs):
        if not torch.is_tensor(x):
            if not hasattr(x, "items"):
                x = self.preprocess(text = x, padding = padding, truncation = truncation, **kwargs)
            x = self.encode_prompt(x, pooling = False, skip = skip, skip_pool = None, padding = padding, truncation = truncation, use_attn_mask = use_attn_mask, normalize = False, normalize_pool = False, **kwargs)
        if context is None:
            context = x
        else:
            batch_size, n_token = x.shape[:2]
            if context.shape[0] == 1 and batch_size != 1:
                context = context.repeat_interleave(batch_size, dim = 0)
            if weight is not None and weight.shape[0] == 1:
                weight = weight.repeat_interleave(batch_size, dim = 0)
            context_size = context.shape[1]
            # the adapter attends over the target prompt tokens followed by the reference tokens
            context = torch.cat([x, context], dim = 1)
            if weight is not None:
                # per-token weights (all ones) for the target prompt, then one weight per reference prompt
                extra_weight = torch.ones((batch_size, n_token), dtype = torch.float32, device = weight.device)
                weight = torch.cat([extra_weight, weight], dim = 1)
        hidden_state, pool = self.adapter(context, weight = weight, alpha = alpha)
        hidden_state = self.normalize_text_hidden_state(hidden_state) if normalize else hidden_state
        if pooling:
            pool = self.normalize_text_hidden_state(pool) if normalize_pool else pool
            return (hidden_state, pool)
        return hidden_state

    def to(self, device):
        self.model.to(device)
        self.adapter.to(device)
        self.device = device
        return self

    def eval(self):
        self.model.eval()
        if self.clip and hasattr(self.model, "text_projection"):
            self.model.text_model.final_layer_norm.requires_grad_(False)
            self.model.text_projection.requires_grad_(False)
        self.adapter.eval()
        return self

    def train(self):
        # the backbone text encoder always stays frozen; only the adapter switches to training mode
        self.model.eval()
        if self.clip and hasattr(self.model, "text_projection"):
            self.model.text_model.final_layer_norm.requires_grad_(False)
            self.model.text_projection.requires_grad_(False)
        self.adapter.train()
        return self

    def parameters(self):
        return list(self.adapter.parameters())

    def named_parameters(self):
        return list(self.adapter.named_parameters())
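
# Hedged end-to-end sketch (not part of the original code): wires DrUM to a Hugging Face CLIP text
# encoder. The checkpoint name, the CLIPModel/CLIPProcessor pair, and the example prompts are
# assumptions for illustration only.
def _example_drum():
    from transformers import CLIPModel, CLIPProcessor    # assumed available in the environment
    model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
    drum = DrUM(model, processor)
    # personalized encoding: a target prompt conditioned on two weighted reference prompts
    hidden_state, pool = drum("a photo of a dog in the park",
                              ref_x = ["a watercolor painting", "soft pastel colors"],
                              weight = [0.7, 0.3], alpha = 0.3)
    # only the adapter is trainable, so an optimizer would be built from drum.parameters()
    return hidden_state.shape, pool.shape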