import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np


def generate_time_series_data(num_samples, sequence_length):
    """Generate a random univariate time-series regression dataset.

    Each sample is a (sequence_length, 1) sequence of N(0, 1) noise; the
    target is the sum over the sequence plus small Gaussian noise, so the
    task is learnable but not trivial.

    Args:
        num_samples: number of sequences to generate.
        sequence_length: time steps per sequence.

    Returns:
        X: tensor of shape (num_samples, sequence_length, 1).
        y: tensor of shape (num_samples, 1).
    """
    X = torch.randn(num_samples, sequence_length, 1)
    y = (X.sum(dim=(1, 2)) + 0.1 * torch.randn(num_samples)).view(-1, 1)
    return X, y


class SimpleLSTM(nn.Module):
    """Single-layer LSTM with a linear head.

    Regresses a scalar from the final hidden state of the last LSTM layer.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # h_n has shape (num_layers, batch, hidden); h_n[-1] is the last
        # layer's final hidden state for each sequence in the batch.
        _, (h_n, _) = self.lstm(x)
        return self.fc(h_n[-1])


def calculate_fisher(model, dataloader, device):
    """Estimate a diagonal Fisher information vector for all parameters.

    Runs one pass over `dataloader`, squares the per-batch gradients of the
    MSE loss, and averages them into a single flat numpy vector (one entry
    per model parameter, in `model.parameters()` order).

    NOTE(review): dividing each batch term by len(dataset) AND taking the
    mean over batches double-normalizes the estimate. Kept as-is to
    preserve the original scaling — it only rescales the effective lambda_.

    Args:
        model: the trained model whose parameters are to be anchored.
        dataloader: batches of (inputs, labels) from the old task.
        device: device to run the forward/backward passes on.

    Returns:
        1-D numpy array with one non-negative entry per parameter.
    """
    fisher_info = []
    model.eval()
    criterion = nn.MSELoss()
    for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        model.zero_grad()
        loss.backward()
        gradients = [
            param.grad.flatten().detach().cpu().numpy()
            for param in model.parameters()
        ]
        fisher_info.append(
            np.square(np.concatenate(gradients)) / len(dataloader.dataset)
        )
    return np.mean(fisher_info, axis=0)


def ewc_loss(fisher_information, weight, weight_old, lambda_):
    """Elastic Weight Consolidation penalty.

    Computes (lambda_ / 2) * sum_i F_i * (theta_i - theta*_i)^2.

    Bug fixes vs. the original:
      * the exponent operator was missing ("(weight - weight_old) 2",
        a syntax error) — restored to `** 2`;
      * `torch.tensor(fisher_information)` produced a CPU float64 tensor,
        which crashes with a device mismatch when the weights live on
        CUDA — `torch.as_tensor` with explicit dtype/device fixes that.

    Args:
        fisher_information: flat numpy Fisher vector (see calculate_fisher).
        weight: current flattened parameter vector.
        weight_old: anchor (previous) flattened parameter vector.
        lambda_: penalty strength.

    Returns:
        Scalar tensor with the penalty value.
    """
    fisher = torch.as_tensor(
        fisher_information, dtype=weight.dtype, device=weight.device
    )
    return lambda_ / 2 * torch.sum(fisher * (weight - weight_old) ** 2)


def main():
    # matplotlib is only needed for the visualizations; importing lazily
    # keeps the module importable (e.g. for testing) without it installed.
    import matplotlib.pyplot as plt

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SimpleLSTM(input_size=1, hidden_size=64, output_size=1).to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    criterion = nn.MSELoss()

    # Task A: generate the initial time-series data and train on it.
    X_initial, y_initial = generate_time_series_data(100, 10)
    initial_dataset = torch.utils.data.TensorDataset(X_initial, y_initial)
    initial_dataloader = torch.utils.data.DataLoader(
        initial_dataset, batch_size=32, shuffle=True
    )

    for epoch in range(50):
        for inputs, labels in initial_dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # Plot the fit on the initial task.
    model.eval()
    with torch.no_grad():
        initial_predictions = model(X_initial.to(device)).cpu().numpy()

    plt.plot(y_initial.numpy(), label='Actual data')
    plt.plot(initial_predictions, label='Initial predictions', color='red')
    plt.title('Initial Training Results')
    plt.xlabel('Time step')
    plt.ylabel('y')
    plt.legend()
    plt.show()

    # Fisher information at the task-A solution, used to weight the EWC
    # penalty during incremental learning.
    fisher_info = calculate_fisher(model, initial_dataloader, device)

    # Task B: new data, learned incrementally with an EWC penalty.
    X_new, y_new = generate_time_series_data(50, 10)
    new_dataset = torch.utils.data.TensorDataset(X_new, y_new)
    new_dataloader = torch.utils.data.DataLoader(
        new_dataset, batch_size=32, shuffle=True
    )

    for epoch in range(20):
        for inputs, labels in new_dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            # Epoch 0 has no stored anchor weights yet, so the penalty is
            # skipped there (as in the original script).
            if epoch > 0:
                current_weights = torch.cat(
                    [param.view(-1) for param in model.parameters()]
                )
                loss = loss + ewc_loss(
                    fisher_info, current_weights, prev_weights, 0.1
                )
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # NOTE(review): canonical EWC anchors to the *task-A optimum*;
            # re-snapshotting after every step anchors to the previous step
            # instead. Kept to preserve the original behavior — confirm
            # whether the task-A snapshot was intended.
            prev_weights = torch.cat(
                [param.view(-1) for param in model.parameters()]
            ).detach().clone()

    # Evaluate retention: predict the ORIGINAL task's data after learning
    # the new task, to visualize how much was forgotten.
    model.eval()
    with torch.no_grad():
        new_predictions = model(X_initial.to(device)).cpu().numpy()

    plt.plot(y_initial.numpy(), label='Actual data')
    plt.plot(new_predictions, label='Incremental learning predictions', color='green')
    plt.title('Incremental Learning Results')
    plt.xlabel('Time step')
    plt.ylabel('y')
    plt.legend()
    plt.show()


if __name__ == "__main__":
    main()
讯享网

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容,请联系我们,一经查实,本站将立刻删除。
如需转载请保留出处:https://51itzy.com/kjqy/40487.html