| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| import random |
| from collections import deque |
| import heapq |
|
|
| |
class QNetwork(nn.Module):
    """Fully connected network mapping a state vector to one Q-value per action.

    Every hidden layer is the stack ``Linear -> BatchNorm1d -> ReLU -> Dropout``;
    a final ``Linear`` layer produces ``action_size`` outputs.

    Note:
        Because of the ``BatchNorm1d`` layers, a forward pass in training mode
        needs more than one sample per batch. Call ``eval()`` before running
        single-sample inference.

    Args:
        state_size: Number of input features per state.
        action_size: Number of discrete actions (network outputs).
        hidden_sizes: Width of each hidden layer, in order.
        dropout_rate: Dropout probability applied after every hidden layer.
    """

    # The default is a tuple so that it can never be mutated and shared
    # between instances (the classic mutable-default-argument pitfall).
    def __init__(self, state_size, action_size, hidden_sizes=(256, 128, 64), dropout_rate=0.1):
        super().__init__()
        self.state_size = state_size
        self.action_size = action_size

        layers = []
        prev_size = state_size
        for hidden_size in hidden_sizes:
            layers.extend(
                (
                    nn.Linear(prev_size, hidden_size),
                    nn.BatchNorm1d(hidden_size),
                    nn.ReLU(),
                    nn.Dropout(dropout_rate),
                )
            )
            prev_size = hidden_size
        # Output head: one Q-value per action, no activation.
        layers.append(nn.Linear(prev_size, action_size))

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Return Q-values of shape ``(batch, action_size)`` for ``x`` of shape ``(batch, state_size)``."""
        return self.network(x)
|
|
| |
class PriorityReplayMemory:
    """Fixed-capacity replay buffer with proportional prioritized sampling.

    Experiences live in a ring buffer (``experiences``) and their priorities in
    a parallel array (``priorities``), so slot ``i`` of one always describes
    slot ``i`` of the other. Indices returned by :meth:`sample` are slot
    numbers and can be passed straight to :meth:`update_priorities`.

    Args:
        capacity: Maximum number of stored experiences (must be >= 1). Once
            full, the oldest experience is overwritten.
        alpha: Exponent applied to ``|error| + eps``; 0 gives uniform sampling.
        beta: Initial importance-sampling exponent.
        beta_increment: Amount added to ``beta`` after every successful
            :meth:`sample` call, capped at 1.0.

    Raises:
        ValueError: If ``capacity`` is smaller than 1.
    """

    # Keeps every priority strictly positive so no experience is unsampleable.
    _EPSILON = 1e-5

    def __init__(self, capacity, alpha=0.6, beta=0.4, beta_increment=0.001):
        if capacity < 1:
            raise ValueError("capacity must be at least 1")
        self.capacity = capacity
        self.alpha = alpha
        self.beta = beta
        self.beta_increment = beta_increment
        self.experiences = []
        self.priorities = np.zeros(capacity, dtype=np.float64)
        self.position = 0  # next slot to write (oldest entry once full)
        # Largest alpha-scaled priority seen so far; used for new experiences
        # that arrive without an error estimate.
        self.max_priority = 1.0

    def _scaled_priority(self, error):
        """Return the alpha-scaled priority for a TD error."""
        return float((abs(error) + self._EPSILON) ** self.alpha)

    def add(self, experience, error=None):
        """Store ``experience``, overwriting the oldest entry when full.

        Args:
            experience: Arbitrary object to store.
            error: Optional TD error used to derive the priority. When
                omitted, the experience receives the current maximum priority
                so that it is likely to be replayed at least once.
        """
        if error is None:
            # max_priority is already alpha-scaled; do not scale it again.
            priority = self.max_priority
        else:
            priority = self._scaled_priority(error)
            self.max_priority = max(self.max_priority, priority)

        if len(self.experiences) < self.capacity:
            self.experiences.append(experience)
        else:
            self.experiences[self.position] = experience
        self.priorities[self.position] = priority
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct experiences proportionally to priority.

        Returns:
            ``(samples, indices, weights)`` where ``samples`` is a list of
            experiences, ``indices`` is an integer array of their slots and
            ``weights`` is a float32 tensor of shape ``(batch_size,)`` holding
            importance-sampling weights normalised so the largest is 1.
            Returns ``(None, None, None)`` when fewer than ``batch_size``
            experiences are stored (or ``batch_size`` is not positive).
        """
        size = len(self.experiences)
        if batch_size <= 0 or size < batch_size:
            return None, None, None

        # Filled slots are always the first `size` entries of the ring buffer.
        priorities = self.priorities[:size]
        probs = priorities / priorities.sum()

        indices = np.random.choice(size, batch_size, p=probs, replace=False)
        samples = [self.experiences[idx] for idx in indices]

        weights = (size * probs[indices]) ** (-self.beta)
        weights /= weights.max()

        self.beta = min(1.0, self.beta + self.beta_increment)
        return samples, indices, torch.as_tensor(weights, dtype=torch.float32)

    def update_priorities(self, indices, errors):
        """Recompute the priorities of the slots in ``indices`` from new TD ``errors``."""
        for idx, error in zip(indices, errors):
            priority = self._scaled_priority(error)
            self.priorities[idx] = priority
            self.max_priority = max(self.max_priority, priority)

    def __len__(self):
        """Return the number of experiences currently stored."""
        return len(self.experiences)
|
|
| |
class RLAgent:
    """DQN / Double-DQN agent with prioritized replay and epsilon-greedy exploration.

    Args:
        state_size: Number of features in a state vector.
        action_size: Number of discrete actions.
        lr: Adam learning rate for the policy network.
        gamma: Discount factor.
        epsilon: Initial exploration probability.
        epsilon_decay: Multiplicative decay applied to ``epsilon`` after every
            training step.
        min_epsilon: Lower bound for ``epsilon``.
        memory_capacity: Capacity of the prioritized replay memory.
        batch_size: Number of experiences per training step. The Q-network
            uses ``BatchNorm1d``, so training batches should hold more than
            one sample.
        target_update_freq: The target network is synchronised whenever
            ``steps`` (incremented by :meth:`act`) is a multiple of this value
            at training time.
        use_double_dqn: Select next actions with the policy network and
            evaluate them with the target network.
        clip_grad_norm: Maximum gradient norm used for clipping.
    """

    def __init__(self, state_size, action_size,
                 lr=0.0005, gamma=0.99, epsilon=1.0, epsilon_decay=0.995, min_epsilon=0.01,
                 memory_capacity=10000, batch_size=64, target_update_freq=1000,
                 use_double_dqn=True, clip_grad_norm=1.0):
        self.state_size = state_size
        self.action_size = action_size
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.min_epsilon = min_epsilon
        self.batch_size = batch_size
        self.use_double_dqn = use_double_dqn
        self.clip_grad_norm = clip_grad_norm

        self.policy_net = QNetwork(state_size, action_size)
        self.target_net = QNetwork(state_size, action_size)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)
        # Element-wise loss so importance-sampling weights can be applied to
        # each sample before averaging.
        self.criterion = nn.SmoothL1Loss(reduction="none")

        self.memory = PriorityReplayMemory(memory_capacity)
        self.steps = 0
        self.target_update_freq = target_update_freq

    @staticmethod
    def _to_tensor(values):
        """Convert array-like ``values`` to a float32 tensor (stacking sequences of arrays)."""
        return torch.as_tensor(np.asarray(values, dtype=np.float32))

    @staticmethod
    def _predict(network, inputs):
        """Run ``network`` without gradients in eval mode, then restore its previous mode.

        Eval mode makes BatchNorm use its running statistics (so a single
        sample is accepted) and disables dropout.
        """
        was_training = network.training
        network.eval()
        try:
            with torch.no_grad():
                return network(inputs)
        finally:
            network.train(was_training)

    def remember(self, state, action, reward, next_state, done):
        """Store a transition, prioritised by its current absolute TD error."""
        state_tensor = self._to_tensor(state).unsqueeze(0)
        next_state_tensor = self._to_tensor(next_state).unsqueeze(0)

        # Network output has shape (1, action_size): index the batch row first.
        q_value = self._predict(self.policy_net, state_tensor)[0, action].item()
        next_q = self._predict(self.target_net, next_state_tensor).max().item()

        target = reward + (1.0 - float(done)) * self.gamma * next_q
        error = abs(q_value - target)
        self.memory.add((state, action, reward, next_state, done), error)

    def act(self, state):
        """Return an epsilon-greedy action index for ``state`` and count the step."""
        self.steps += 1
        if random.random() < self.epsilon:
            return random.randint(0, self.action_size - 1)
        state_tensor = self._to_tensor(state).unsqueeze(0)
        return torch.argmax(self._predict(self.policy_net, state_tensor)).item()

    def train(self):
        """Run one optimisation step on a prioritized batch.

        Does nothing until the memory holds at least ``batch_size``
        experiences. After the gradient step the sampled priorities are
        refreshed from the new TD errors, the target network is synchronised
        when due, and ``epsilon`` is decayed.
        """
        if len(self.memory) < self.batch_size:
            return

        batch, indices, weights = self.memory.sample(self.batch_size)
        if batch is None:
            return

        states, actions, rewards, next_states, dones = zip(*batch)
        states = self._to_tensor(states)
        actions = torch.as_tensor(np.asarray(actions, dtype=np.int64))
        rewards = self._to_tensor(rewards)
        next_states = self._to_tensor(next_states)
        dones = self._to_tensor(dones)
        # One weight per sample, aligned with the per-sample loss below.
        weights = weights.reshape(-1)

        q_values = self.policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        # Bootstrap targets are constants: compute them without gradients.
        target_next_q = self._predict(self.target_net, next_states)
        if self.use_double_dqn:
            next_actions = self._predict(self.policy_net, next_states).argmax(1, keepdim=True)
            next_q_values = target_next_q.gather(1, next_actions).squeeze(1)
        else:
            next_q_values = target_next_q.max(1)[0]

        targets = rewards + self.gamma * next_q_values * (1 - dones)

        td_errors = (q_values - targets).detach().cpu().numpy()

        loss = (self.criterion(q_values, targets) * weights).mean()

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), self.clip_grad_norm)
        self.optimizer.step()

        self.memory.update_priorities(indices, td_errors)

        if self.steps % self.target_update_freq == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())

        self.epsilon = max(self.min_epsilon, self.epsilon * self.epsilon_decay)