Raziel1234 commited on
Commit
8c7686d
ยท
verified ยท
1 Parent(s): 303ab82

Update modeling_pixel.py

Browse files
Files changed (1) hide show
  1. modeling_pixel.py +8 -6
modeling_pixel.py CHANGED
@@ -19,18 +19,20 @@ class ResidualBlock(nn.Module):
19
 
20
  class TopAIImageGenerator(PreTrainedModel):
21
  config_class = TopAIImageConfig
22
- # ืฉื™ืžื•ืฉ ื‘ืจืฉื™ืžื” ืจื™ืงื” ื›ืกื˜ื ื“ืจื˜ ืฉืœ Transformers
23
- all_tied_weights_keys = []
 
24
 
25
  def __init__(self, config):
26
  super().__init__(config)
27
- h = config.hidden_dim # ืžืงื‘ืœ 512 ืžื”-config.json
 
28
 
29
  self.text_projection = nn.Linear(config.input_dim, 4 * 4 * h)
30
 
31
- # ื‘ื ื™ื™ื” ื“ื™ื ืžื™ืช ืฉืžืชืื™ืžื” ืœืžืฉืงื•ืœื•ืช ืฉืื•ืžื ื• ื‘-512
32
  self.decoder = nn.Sequential(
33
- # ืฉื›ื‘ื” 0: ืž-512 ืœ-512 (ื›ืืŸ ื”ื™ื™ืชื” ื”ืฉื’ื™ืื”!)
34
  self._upsample(h, h),
35
  # ืฉื›ื‘ื” 1
36
  ResidualBlock(h),
@@ -59,7 +61,7 @@ class TopAIImageGenerator(PreTrainedModel):
59
  )
60
 
61
  def forward(self, text_embeddings):
62
- # ื”ืžืจืช ื”-Embedding ืœืžืคื” ืฉืœ 4x4
63
  x = self.text_projection(text_embeddings)
64
  x = x.view(-1, self.config.hidden_dim, 4, 4)
65
  return self.decoder(x)
 
19
 
20
  class TopAIImageGenerator(PreTrainedModel):
21
  config_class = TopAIImageConfig
22
+
23
+ # ืชื™ืงื•ืŸ ื”-AttributeError: ื—ื™ื™ื‘ ืœื”ื™ื•ืช ืžื™ืœื•ืŸ (dict) ื›ื“ื™ ืฉืชื”ื™ื” ืœื• ืžืชื•ื“ื” .keys()
24
+ all_tied_weights_keys = {}
25
 
26
  def __init__(self, config):
27
  super().__init__(config)
28
+ # ืฉื™ืžื•ืฉ ื‘-hidden_dim ืžื”ืงื•ื ืคื™ื’ (512)
29
+ h = config.hidden_dim
30
 
31
  self.text_projection = nn.Linear(config.input_dim, 4 * 4 * h)
32
 
33
+ # ื‘ื ื™ื™ื” ื“ื™ื ืžื™ืช ืฉืžืชืื™ืžื” ื‘ื“ื™ื•ืง ืœืžืฉืงื•ืœื•ืช ื‘-Safetensors
34
  self.decoder = nn.Sequential(
35
+ # ืฉื›ื‘ื” 0: ืž-512 ืœ-512 (ื›ืืŸ ื”ื™ื” ื”-Mismatch)
36
  self._upsample(h, h),
37
  # ืฉื›ื‘ื” 1
38
  ResidualBlock(h),
 
61
  )
62
 
63
  def forward(self, text_embeddings):
64
+ # ืฉื™ื ื•ื™ ืฆื•ืจื” ืœืžืคืช ืžืืคื™ื™ื ื™ื ืจืืฉื•ื ื™ืช
65
  x = self.text_projection(text_embeddings)
66
  x = x.view(-1, self.config.hidden_dim, 4, 4)
67
  return self.decoder(x)