| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from torch.utils.data import DataLoader, Dataset |
| import json |
| import os |
|
|
| |
class CustomDataset(Dataset):
    """Map-style dataset that pairs each input with its label by position.

    Args:
        texts: Indexable, sized collection of inputs. In this script it is a
            tensor of shape (num_samples, seq_len, input_size).
        labels: Indexable, sized collection of targets aligned index-for-index
            with ``texts``.

    Raises:
        ValueError: If ``texts`` and ``labels`` have different lengths.
    """

    def __init__(self, texts, labels):
        # Fail fast: __len__ only looks at ``texts``, so a mismatch would
        # otherwise silently ignore extra labels or raise IndexError mid-epoch.
        if len(texts) != len(labels):
            raise ValueError(
                f"texts and labels must have the same length, "
                f"got {len(texts)} and {len(labels)}"
            )
        self.texts = texts
        self.labels = labels

    def __len__(self):
        """Return the number of samples."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the ``(input, label)`` pair at position ``idx``."""
        return self.texts[idx], self.labels[idx]
|
|
| |
class LSTMModel(nn.Module):
    """Sequence classifier: an LSTM encoder followed by a linear head.

    The head reads only the LSTM output at the final time step.

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of the LSTM hidden state.
        output_size: Number of output logits (classes).
        num_layers: Number of stacked LSTM layers. Defaults to 1, which
            reproduces the original single-layer model.
        dropout: Dropout probability passed to ``nn.LSTM``. PyTorch applies it
            between stacked layers only, so it has no effect (and PyTorch warns)
            when ``num_layers == 1``. Defaults to 0.0.
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers=1, dropout=0.0):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
        )
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return logits of shape (batch, output_size).

        Args:
            x: Tensor of shape (batch, seq_len, input_size). The layout is
                batch-first because the LSTM is built with ``batch_first=True``.
        """
        lstm_out, _ = self.lstm(x)
        # lstm_out is (batch, seq_len, hidden_size); classify from the last step.
        return self.fc(lstm_out[:, -1, :])
|
|
| |
# Model dimensions: per-step feature width, LSTM hidden width, and number of
# output classes (labels are drawn from range(output_size) further below).
input_size, hidden_size, output_size = 100, 64, 10

# Optimisation settings.
num_epochs, learning_rate = 5, 0.001
|
|
| |
# Multi-class objective; it consumes the raw logits from the model's linear head.
criterion = nn.CrossEntropyLoss()

model = LSTMModel(input_size=input_size, hidden_size=hidden_size, output_size=output_size)

# Adam over every trainable parameter of the model.
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)
|
|
| |
# Synthetic stand-in data: 100 random sequences of 10 steps, each step an
# input_size-dim vector (batch-first layout), with integer class labels in
# [0, output_size). Keep texts before labels so the RNG draw order is unchanged.
texts = torch.randn((100, 10, input_size))
labels = torch.randint(low=0, high=output_size, size=(100,))

dataset = CustomDataset(texts=texts, labels=labels)
data_loader = DataLoader(dataset=dataset, batch_size=16, shuffle=True)
|
|
| |
for epoch in range(num_epochs):
    # Ensure training mode so any dropout inside the model is active.
    model.train()
    running_loss = 0.0
    seen = 0
    for inputs, targets in data_loader:
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # criterion uses CrossEntropyLoss's default reduction='mean', so loss is
        # a per-batch average. Weight by batch size so the uneven final batch
        # does not skew the epoch mean.
        n_batch = targets.size(0)
        running_loss += loss.item() * n_batch
        seen += n_batch

    # Report the mean loss over the whole epoch, not just the last batch.
    epoch_loss = running_loss / seen if seen else float("nan")
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')
|
|
| |
model_save_path = "model"
os.makedirs(model_save_path, exist_ok=True)

# Persist the learned parameters (state_dict) only; the architecture itself
# is described by config.json below.
torch.save(model.state_dict(), os.path.join(model_save_path, "pytorch_model.bin"))

# Read every architecture hyperparameter from the instantiated modules instead
# of hard-coding them. The config then always describes the model whose weights
# were just saved. (Previously it claimed dropout=0.2 for an LSTM built without
# dropout.)
config = {
    "input_size": model.lstm.input_size,
    "hidden_size": model.lstm.hidden_size,
    "output_size": model.fc.out_features,
    "num_layers": model.lstm.num_layers,
    "dropout": model.lstm.dropout,
}

with open(os.path.join(model_save_path, "config.json"), "w", encoding="utf-8") as f:
    json.dump(config, f, indent=2)

print("Model and configuration saved successfully!")
|
|