dineth554 committed on
Commit
390dadd
·
verified ·
1 Parent(s): 97cda6b

Upload sagemaker_inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. sagemaker_inference.py +181 -0
sagemaker_inference.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SageMaker Inference Script for Legion Coder 8M
3
+
4
+ This script handles model loading and inference for Amazon SageMaker deployment.
5
+ It follows the SageMaker inference container contract.
6
+ """
7
+
8
+ import os
9
+ import json
10
+ import torch
11
+ import sys
12
+ from pathlib import Path
13
+
14
+ # Add model code to path
15
+ sys.path.append('/opt/ml/model/code')
16
+
17
class LegionCoderModel(torch.nn.Module):
    """Decoder-only transformer LM for Legion Coder inference.

    Token + learned positional embeddings feed a stack of TransformerBlock
    layers (the block class ships with the deployed model code under
    /opt/ml/model/code), followed by a final LayerNorm and a weight-tied
    LM head.
    """

    def __init__(self, vocab_size=16000, d_model=576, num_layers=13,
                 num_heads=16, d_ff=1152, max_seq_len=1024, dropout=0.1,
                 pad_token_id=0):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        self.pad_token_id = pad_token_id
        self.token_embedding = torch.nn.Embedding(vocab_size, d_model)
        self.position_embedding = torch.nn.Embedding(max_seq_len, d_model)
        self.blocks = torch.nn.ModuleList(
            [self._create_block(d_model, num_heads, d_ff, dropout)
             for _ in range(num_layers)]
        )
        self.norm = torch.nn.LayerNorm(d_model)
        self.lm_head = torch.nn.Linear(d_model, vocab_size, bias=False)
        # Weight tying: the LM head shares the embedding matrix.
        self.lm_head.weight = self.token_embedding.weight
        self.dropout = torch.nn.Dropout(dropout)

    def _create_block(self, d_model, num_heads, d_ff, dropout):
        """Create one transformer block (class defined in the deployed code dir)."""
        from model import TransformerBlock
        return TransformerBlock(d_model, num_heads, d_ff, dropout)

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Run the decoder stack.

        Args:
            input_ids: (batch, seq) token-id tensor.
            attention_mask: optional (batch, seq) padding mask; nonzero
                means "attend" (assumed HF-style 0/1 mask — TODO confirm
                against callers).
            labels: optional (batch, seq) targets; positions equal to -100
                are ignored in the loss.

        Returns:
            dict with 'logits' (batch, seq, vocab) and 'loss' (scalar tensor,
            or None when labels is None).
        """
        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        positions = torch.arange(0, seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
        token_embeds = self.token_embedding(input_ids)
        pos_embeds = self.position_embedding(positions)
        x = self.dropout(token_embeds + pos_embeds)

        # Causal mask: True where attention is allowed (lower triangle).
        mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
        causal_mask = mask == 0

        if attention_mask is not None:
            # BUGFIX: cast the padding mask to bool so the bitwise AND does
            # not promote the combined mask to an integer dtype when callers
            # pass the conventional 0/1 integer attention_mask.
            pad_mask = attention_mask.bool().unsqueeze(1).unsqueeze(2)
            causal_mask = causal_mask.unsqueeze(0).unsqueeze(0) & pad_mask

        for block in self.blocks:
            x = block(x, causal_mask)

        x = self.norm(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Standard next-token shift: predict token t+1 from position t.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(shift_logits.view(-1, self.vocab_size), shift_labels.view(-1))

        return {'logits': logits, 'loss': loss}

    def generate(self, input_ids, max_length=100, temperature=1.0, top_k=50,
                 top_p=0.95, pad_token_id=0, eos_token_id=2):
        """Autoregressive sampling with temperature, top-k and top-p filtering.

        Note: max_length counts NEWLY generated tokens, not total sequence
        length. pad_token_id is accepted for API compatibility but unused.

        Returns:
            (batch, seq + generated) tensor of token ids including the prompt.
        """
        self.eval()

        with torch.no_grad():
            for _ in range(max_length):
                # Keep the context within the positional-embedding range.
                if input_ids.shape[1] > self.max_seq_len:
                    input_ids = input_ids[:, -self.max_seq_len:]

                outputs = self.forward(input_ids)
                logits = outputs['logits']
                next_token_logits = logits[:, -1, :] / temperature

                if top_k > 0:
                    # Mask everything strictly below the k-th largest logit.
                    indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
                    next_token_logits[indices_to_remove] = float('-inf')

                if top_p < 1.0:
                    # Nucleus filtering: drop the probability tail beyond
                    # top_p; the shift keeps at least the single best token.
                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                    cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0
                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    next_token_logits[indices_to_remove] = float('-inf')

                probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                input_ids = torch.cat([input_ids, next_token], dim=1)

                # Stop once every sequence in the batch emitted EOS.
                if (next_token == eos_token_id).all():
                    break

        return input_ids
105
+
106
+
107
# SageMaker inference functions
def model_fn(model_dir):
    """SageMaker model-loading hook.

    Reads config.json from *model_dir* for hyperparameters, builds a
    LegionCoderModel, and loads safetensors weights into it.
    """
    print(f"Loading model from {model_dir}")

    # Hyperparameters come from the bundled config, with training defaults.
    with open(os.path.join(model_dir, 'config.json'), 'r') as f:
        config = json.load(f)

    defaults = {
        'vocab_size': 16000,
        'd_model': 576,
        'num_layers': 13,
        'num_heads': 16,
        'd_ff': 1152,
        'max_seq_len': 1024,
        'dropout': 0.1,
        'pad_token_id': 0,
    }
    kwargs = {key: config.get(key, fallback) for key, fallback in defaults.items()}
    model = LegionCoderModel(**kwargs)

    # Load weights; strict=False tolerates keys absent from the checkpoint
    # (e.g. the tied lm_head weight is typically not stored separately).
    from safetensors.torch import load_file
    weights = load_file(os.path.join(model_dir, 'model.safetensors'))
    model.load_state_dict(weights, strict=False)
    model.eval()

    print("Model loaded successfully!")
    return model
136
+
137
+
138
def input_fn(request_body, request_content_type):
    """SageMaker input hook: deserialize the request payload (JSON only).

    Raises:
        ValueError: for any content type other than application/json.
    """
    if request_content_type != 'application/json':
        raise ValueError(f"Unsupported content type: {request_content_type}")
    return json.loads(request_body)
145
+
146
+
147
def predict_fn(input_data, model):
    """SageMaker prediction hook.

    NOTE(review): this is currently a placeholder — it never calls
    ``model.generate`` because no tokenizer is bundled with the deployment.
    It echoes back (a prefix of) the input text; wire in a real tokenizer
    before production use. The dead ``import torch`` and unused sampling
    parameters from the original were removed.

    Args:
        input_data: parsed request dict with 'inputs' (text) and an
            optional 'parameters' dict of generation settings.
        model: loaded model object (unused by the placeholder).

    Returns:
        dict with 'generated_text' and the echoed 'parameters'.
    """
    text = input_data.get('inputs', '')
    parameters = input_data.get('parameters', {})

    # Accepted generation settings (currently ignored by the placeholder):
    # max_length (default 100), temperature (0.8), top_k (50), top_p (0.95).

    return {
        'generated_text': f"Generated response for: {text[:50]}...",
        'parameters': parameters
    }
167
+
168
+
169
def output_fn(prediction, response_content_type):
    """SageMaker output hook: serialize the prediction (JSON only).

    Returns a (body, content_type) tuple; raises ValueError for any
    content type other than application/json.
    """
    if response_content_type != 'application/json':
        raise ValueError(f"Unsupported content type: {response_content_type}")
    return json.dumps(prediction), response_content_type
175
+
176
+
177
if __name__ == "__main__":
    # Informational banner only: real inference runs inside a SageMaker
    # container through the model_fn/input_fn/predict_fn/output_fn hooks.
    banner = (
        "Testing SageMaker inference script...",
        "This script is designed to run within a SageMaker container.",
        "For local testing, use the Streamlit app or direct model loading.",
    )
    for line in banner:
        print(line)