| """
|
| Hugging Face Export for Repo-JEPA
|
|
|
| This file enables loading Repo-JEPA with AutoModel.from_pretrained()
|
| using trust_remote_code=True.
|
| """
|
|
|
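
# A minimal loading sketch (an assumption about deployment: this file and a
# matching config.json live together in a Hub repo; "your-org/repo-jepa" is a
# placeholder repo id, not a real checkpoint):
#
#     from transformers import AutoModel
#     model = AutoModel.from_pretrained("your-org/repo-jepa", trust_remote_code=True)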

from typing import Optional

import torch
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel, RobertaModel


class RepoJEPAConfig(PretrainedConfig):
    """Configuration for the Repo-JEPA model."""

    model_type = "repo-jepa"

    def __init__(
        self,
        hidden_dim: int = 768,
        num_encoder_layers: int = 12,
        num_attention_heads: int = 12,
        intermediate_dim: int = 3072,
        hidden_dropout_prob: float = 0.1,
        attention_dropout_prob: float = 0.1,
        vocab_size: int = 50265,
        max_seq_len: int = 512,
        pad_token_id: int = 1,
        base_model: str = "microsoft/codebert-base",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.num_encoder_layers = num_encoder_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_dim = intermediate_dim
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_dropout_prob = attention_dropout_prob
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.pad_token_id = pad_token_id
        self.base_model = base_model


class ProjectionHead(nn.Module):
    """Three-layer MLP projection head (Linear -> BatchNorm1d -> ReLU, twice, then Linear)."""

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.BatchNorm1d(output_dim),
            nn.ReLU(inplace=True),
            nn.Linear(output_dim, output_dim),
            nn.BatchNorm1d(output_dim),
            nn.ReLU(inplace=True),
            nn.Linear(output_dim, output_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)


class RepoJEPAModel(PreTrainedModel):
    """
    Repo-JEPA: Joint Embedding Predictive Architecture for Code Search.

    Use for semantic code search (encode_code) and retrieval queries (encode_query).
    """

    config_class = RepoJEPAConfig

    def __init__(self, config: RepoJEPAConfig):
        super().__init__(config)

        # Both encoders start from the same pretrained base: the context
        # encoder embeds code, the target encoder embeds queries/docstrings.
        self.context_encoder = RobertaModel.from_pretrained(
            config.base_model,
            add_pooling_layer=False,
        )
        self.target_encoder = RobertaModel.from_pretrained(
            config.base_model,
            add_pooling_layer=False,
        )

        # Project both encoder outputs into a shared embedding space.
        hidden_size = self.context_encoder.config.hidden_size
        self.context_projector = ProjectionHead(hidden_size, config.hidden_dim)
        self.target_projector = ProjectionHead(hidden_size, config.hidden_dim)

        self.post_init()

    def encode_code(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode a code snippet into the shared embedding space."""
        outputs = self.context_encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self._mean_pool(outputs.last_hidden_state, attention_mask)
        return self.context_projector(pooled)

    def encode_query(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode a search query (docstring) into the shared embedding space."""
        outputs = self.target_encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self._mean_pool(outputs.last_hidden_state, attention_mask)
        return self.target_projector(pooled)

    def _mean_pool(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Mean-pool token embeddings, ignoring padded positions."""
        if attention_mask is not None:
            mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
            sum_hidden = torch.sum(hidden_states * mask, dim=1)
            sum_mask = torch.clamp(mask.sum(dim=1), min=1e-9)
            return sum_hidden / sum_mask
        return hidden_states.mean(dim=1)

    def forward(self, **kwargs):
        """Route AutoModel-style calls to encode_code; otherwise require the encode_* methods."""
        if "input_ids" in kwargs:
            return self.encode_code(kwargs["input_ids"], kwargs.get("attention_mask"))
        raise NotImplementedError("Use .encode_code() or .encode_query() directly.")


try:
    from transformers import AutoConfig, AutoModel

    AutoConfig.register("repo-jepa", RepoJEPAConfig)
    AutoModel.register(RepoJEPAConfig, RepoJEPAModel)
except ValueError:
    # Registration raises ValueError if "repo-jepa" is already registered,
    # e.g. when this module is imported more than once.
    pass
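

# A minimal smoke test (a sketch, not part of the export itself). It assumes
# network access to download the microsoft/codebert-base weights and tokenizer.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    config = RepoJEPAConfig()
    model = RepoJEPAModel(config).eval()
    tokenizer = AutoTokenizer.from_pretrained(config.base_model)

    code_inputs = tokenizer("def add(a, b):\n    return a + b", return_tensors="pt")
    query_inputs = tokenizer("add two numbers", return_tensors="pt")

    with torch.no_grad():
        code_emb = model.encode_code(**code_inputs)
        query_emb = model.encode_query(**query_inputs)

    # Cosine similarity between the code and query embeddings.
    print(torch.nn.functional.cosine_similarity(code_emb, query_emb).item())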