# latexOCR / configuration_latex_decoder.py
# Uploaded by harryrobert with huggingface_hub (commit e9daea4, verified).
from transformers import PretrainedConfig
class LaTeXDecoderConfig(PretrainedConfig):
    """Configuration for the LaTeX decoder model.

    Holds the vocabulary size, special-token ids and architecture
    hyper-parameters. The special-token ids are stored both under the
    short names used here (``pad_id``/``bos_id``/``eos_id``) and, via
    ``PretrainedConfig.__init__``, as ``pad_token_id``/``bos_token_id``/
    ``eos_token_id``.

    The model code that consumes these values is not in this file; the
    descriptions below follow the parameter names and should be confirmed
    against the decoder implementation.

    Args:
        vocab_size: Size of the token vocabulary.
        pad_id: Id of the padding token.
        bos_id: Id of the beginning-of-sequence token.
        eos_id: Id of the end-of-sequence token.
        d_model: Model width; must be divisible by ``n_heads`` (see
            ``head_dim``).
        n_heads: Number of attention heads; must be positive.
        n_layers: Number of decoder layers.
        d_ff: Inner width of the feed-forward sublayer (by name).
        dropout: Dropout probability (by name).
        max_seq_len: Maximum sequence length (by name).
        rope_theta: Base for rotary position embeddings (by name).
        tie_weights: Weight-tying flag, stored as ``self.tie_weights``.
            NOTE(review): this is not forwarded to the base class's
            ``tie_word_embeddings``; confirm the model reads this flag
            rather than relying on transformers' own tying logic.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    # Identifier transformers serializes into config.json and uses to map
    # the JSON back to this config class.
    model_type = "latex_decoder"

    def __init__(
        self,
        vocab_size: int = 8192,
        pad_id: int = 0,
        bos_id: int = 2,
        eos_id: int = 3,
        d_model: int = 512,
        n_heads: int = 8,
        n_layers: int = 6,
        d_ff: int = 1408,
        dropout: float = 0.1,
        max_seq_len: int = 200,
        rope_theta: float = 10000.0,
        tie_weights: bool = True,
        **kwargs,
    ):
        # The *_id arguments are authoritative. Drop any *_token_id entries
        # from kwargs (such as may be present when a saved config dict is
        # reloaded). Otherwise they would collide with the explicit keywords
        # passed to super().__init__() below and raise
        # "TypeError: got multiple values for keyword argument".
        kwargs.pop("pad_token_id", None)
        kwargs.pop("bos_token_id", None)
        kwargs.pop("eos_token_id", None)
        super().__init__(
            pad_token_id=pad_id,
            bos_token_id=bos_id,
            eos_token_id=eos_id,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.pad_id = pad_id
        self.bos_id = bos_id
        self.eos_id = eos_id
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.d_ff = d_ff
        self.dropout = dropout
        self.max_seq_len = max_seq_len
        self.rope_theta = rope_theta
        self.tie_weights = tie_weights

    @property
    def head_dim(self) -> int:
        """Return the per-head width, ``d_model // n_heads``.

        Raises:
            ValueError: If ``n_heads`` is not positive, or if it does not
                divide ``d_model`` evenly.
        """
        # Explicit checks instead of ``assert``: asserts are stripped under
        # ``python -O``, which would silently return a floored width.
        if self.n_heads <= 0:
            raise ValueError(f"n_heads must be positive, got {self.n_heads}")
        if self.d_model % self.n_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by "
                f"n_heads ({self.n_heads})"
            )
        return self.d_model // self.n_heads