import gym
import math
import random
import sys
from collections import deque

import numpy as np
import pygame
import torch
import torch.nn as nn
import torch.optim as optim
# Define the DQN model
class DQN(nn.Module):
    def __init__(self):
        super(DQN, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(4, 128),   # CartPole observation has 4 dimensions
            nn.ReLU(),
            nn.Linear(128, 2)    # one Q-value per action (push left / push right)
        )

    def forward(self, x):
        return self.network(x)
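# Note: the input size (4) and output size (2) are hard-coded for CartPole-v1;
# for other environments they would come from env.observation_space and
# env.action_space instead.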
# Experience replay buffer
class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = zip(*batch)
        return state, action, reward, next_state, done

    def __len__(self):
        return len(self.buffer)
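# Sampling minibatches uniformly from past transitions breaks the correlation
# between consecutive steps, which is one of the key stabilization tricks in DQN.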
# One optimization step on a minibatch of transitions
def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    states, actions, rewards, next_states, dones = memory.sample(BATCH_SIZE)
    # Stack into a single ndarray first; building a tensor from a list of
    # ndarrays is slow and triggers a warning in recent PyTorch versions.
    states = torch.tensor(np.array(states), dtype=torch.float)
    next_states = torch.tensor(np.array(next_states), dtype=torch.float)
    actions = torch.tensor(actions, dtype=torch.long)
    rewards = torch.tensor(rewards, dtype=torch.float)
    dones = torch.tensor(dones, dtype=torch.float)
    # Q(s, a) for the actions actually taken
    current_q_values = model(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    # Bootstrap value: max_a' Q(s', a'), detached so no gradient flows through it
    next_q_values = model(next_states).max(1)[0].detach()
    expected_q_values = rewards + 0.99 * next_q_values * (1 - dones)
    loss = criterion(current_q_values, expected_q_values)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
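# This implements the one-step Q-learning target
#     y = r + gamma * max_a' Q(s', a') * (1 - done),  with gamma = 0.99.
# Note that this sketch bootstraps from the online network itself; the original
# DQN paper uses a separate, periodically synced target network for stability.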
# Set up the environment and model
env = gym.make('CartPole-v1')
model = DQN()
memory = ReplayBuffer(10000)
optimizer = optim.Adam(model.parameters())
criterion = nn.MSELoss()
BATCH_SIZE = 128
EPSILON = 0.2
pygame.init()
screen = pygame.display.set_mode((600, 400))
clock = pygame.time.Clock()
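# EPSILON is the epsilon-greedy exploration rate: with probability EPSILON a
# random action is taken instead of the greedy one. It is decayed by 0.995
# after each episode at the bottom of the training loop.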
# Training loop
num_episodes = 500
for episode in range(num_episodes):
    state, _ = env.reset()  # gym >= 0.26 returns (observation, info)
    total_reward = 0
    done = False
    while not done:
        # Epsilon-greedy action selection
        if random.random() < EPSILON:
            action = env.action_space.sample()
        else:
            state_tensor = torch.tensor(state, dtype=torch.float).unsqueeze(0)
            action = model(state_tensor).max(1)[1].item()
        # gym >= 0.26 returns separate terminated/truncated flags
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        memory.push(state, action, reward, next_state, done)
        state = next_state
        total_reward += reward
        optimize_model()
        # Pygame visualization
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                sys.exit()
        screen.fill((255, 255, 255))
        # Map the cart position (roughly [-2.4, 2.4]) to screen x-coordinates
        cart_x = int(state[0] * 100 + 300)
        pygame.draw.rect(screen, (0, 0, 255), (cart_x, 300, 50, 30))
        # Draw the pole from the cart center at its current angle (state[2])
        pole_tip = (cart_x + 25 - int(50 * math.sin(state[2])),
                    300 - int(50 * math.cos(state[2])))
        pygame.draw.line(screen, (255, 0, 0), (cart_x + 25, 300), pole_tip, 5)
        pygame.display.flip()
        clock.tick(60)  # cap rendering (and hence training) at 60 steps per second
    EPSILON *= 0.995  # decay the exploration rate each episode
    print(f'Episode {episode}: Total Reward = {total_reward}')
env.close()
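# To run this script you need gym (>= 0.26 for the (obs, info) reset API),
# torch, numpy, and pygame installed, e.g. pip install gym torch numpy pygame.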