eliphatfs commited on
Commit ·
24cb86c
1
Parent(s): 654bd81
Avoid redundant model loading.
Browse files
openshape/demo/caption.py
CHANGED
|
@@ -2,7 +2,7 @@ from torch import nn
|
|
| 2 |
import numpy as np
|
| 3 |
import torch
|
| 4 |
from typing import Tuple, List, Union, Optional
|
| 5 |
-
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
| 6 |
from huggingface_hub import hf_hub_download
|
| 7 |
|
| 8 |
|
|
@@ -60,7 +60,7 @@ class ClipCaptionModel(nn.Module):
|
|
| 60 |
def __init__(self, prefix_length: int, prefix_size: int = 512):
|
| 61 |
super(ClipCaptionModel, self).__init__()
|
| 62 |
self.prefix_length = prefix_length
|
| 63 |
-
self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
|
| 64 |
self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
|
| 65 |
if prefix_length > 10: # not enough memory
|
| 66 |
self.clip_project = nn.Linear(prefix_size, self.gpt_embedding_size * prefix_length)
|
|
|
|
| 2 |
import numpy as np
|
| 3 |
import torch
|
| 4 |
from typing import Tuple, List, Union, Optional
|
| 5 |
+
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
|
| 6 |
from huggingface_hub import hf_hub_download
|
| 7 |
|
| 8 |
|
|
|
|
| 60 |
def __init__(self, prefix_length: int, prefix_size: int = 512):
|
| 61 |
super(ClipCaptionModel, self).__init__()
|
| 62 |
self.prefix_length = prefix_length
|
| 63 |
+
self.gpt = GPT2LMHeadModel(GPT2Config.from_pretrained('gpt2'))
|
| 64 |
self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
|
| 65 |
if prefix_length > 10: # not enough memory
|
| 66 |
self.clip_project = nn.Linear(prefix_size, self.gpt_embedding_size * prefix_length)
|