A Simple Regression Example Using EWC for Incremental Learning
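Elastic Weight Consolidation (EWC) is a regularization method for incremental (continual) learning: after a model finishes an old task, each parameter's importance to that task is estimated with the diagonal Fisher information, and training on new data then penalizes moving important parameters away from their old values. Writing the old-task weights as θ* and the Fisher estimates as F_i, the loss on the new task becomes

    L(θ) = L_new(θ) + (λ / 2) · Σ_i F_i · (θ_i − θ_i*)²

where λ controls how strongly old knowledge is protected. The complete script below applies this penalty to a small LSTM regressor on synthetic time-series data: it trains on an initial dataset, estimates the Fisher information, then continues training on new data with the EWC term added to the loss.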


import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

# Generate random time-series data
def generate_time_series_data(num_samples, sequence_length):
    X = torch.randn(num_samples, sequence_length, 1)  # univariate time series
    y = (X.sum(dim=(1, 2)) + 0.1 * torch.randn(num_samples)).view(-1, 1)
    return X, y

# A simple LSTM regression model
class SimpleLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleLSTM, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        _, (h_n, _) = self.lstm(x)
        return self.fc(h_n[-1])  # regress from the last hidden state

# Estimate the diagonal Fisher information as the average of the
# squared gradients of the loss over the dataset
def calculate_fisher(model, dataloader, device):
    fisher_info = []
    model.eval()
    criterion = nn.MSELoss()
    for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        model.zero_grad()
        loss.backward()
        gradients = [param.grad.flatten().detach().cpu().numpy()
                     for param in model.parameters()]
        fisher_info.append(np.square(np.concatenate(gradients)))
    return np.mean(fisher_info, axis=0)  # average over batches

# EWC penalty: (lambda / 2) * sum_i F_i * (theta_i - theta*_i)^2
# fisher_information is expected as a tensor on the same device as the weights
def ewc_loss(fisher_information, weight, weight_old, lambda_):
    return lambda_ / 2 * torch.sum(fisher_information * (weight - weight_old) ** 2)

# Initialize model, optimizer, and loss
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleLSTM(input_size=1, hidden_size=64, output_size=1).to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

# Generate the initial time-series data
X_initial, y_initial = generate_time_series_data(100, 10)
initial_dataset = torch.utils.data.TensorDataset(X_initial, y_initial)
initial_dataloader = torch.utils.data.DataLoader(initial_dataset, batch_size=32, shuffle=True)

# Initial training
model.train()
for epoch in range(50):
    for inputs, labels in initial_dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Plot the fit after initial training
model.eval()
with torch.no_grad():
    initial_predictions = model(X_initial.to(device)).cpu().numpy()

plt.plot(y_initial.numpy(), label='Actual data')
plt.plot(initial_predictions, label='Initial predictions', color='red')
plt.title('Initial Training Results')
plt.xlabel('Time step')
plt.ylabel('y')
plt.legend()
plt.show()

# Compute the Fisher information and snapshot the weights of the
# initial task; together these form the EWC anchor
fisher_info = calculate_fisher(model, initial_dataloader, device)
fisher_tensor = torch.as_tensor(fisher_info, dtype=torch.float32, device=device)
old_weights = torch.cat([param.view(-1) for param in model.parameters()]).detach().clone()

# Simulate the incremental-learning scenario with new time-series data
X_new, y_new = generate_time_series_data(50, 10)
new_dataset = torch.utils.data.TensorDataset(X_new, y_new)
new_dataloader = torch.utils.data.DataLoader(new_dataset, batch_size=32, shuffle=True)

# Incremental training with the EWC penalty
model.train()
for epoch in range(20):
    for inputs, labels in new_dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Penalize drifting away from the initial-task weights
        current_weights = torch.cat([param.view(-1) for param in model.parameters()])
        loss = loss + ewc_loss(fisher_tensor, current_weights, old_weights, 0.1)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Evaluate on the *initial* data again: with EWC the model should
# still fit the old task reasonably well after learning the new one
model.eval()
with torch.no_grad():
    new_predictions = model(X_initial.to(device)).cpu().numpy()

plt.plot(y_initial.numpy(), label='Actual data')
plt.plot(new_predictions, label='Incremental learning predictions', color='green')
plt.title('Incremental Learning Results')
plt.xlabel('Time step')
plt.ylabel('y')
plt.legend()
plt.show()
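The script above handles a single old-task → new-task step. To keep learning across more tasks, a common pattern is to re-estimate the Fisher information and re-snapshot the weights after each finished task, then sum one EWC penalty per consolidated task during later training. Here is a minimal sketch under that assumption; task_loaders and train_one_task are hypothetical placeholders, not part of the original script:

    anchors = []  # one (fisher, weights) pair per consolidated task

    for task_loader in task_loaders:  # assumed list of DataLoaders, one per task
        def penalty():
            theta = torch.cat([p.view(-1) for p in model.parameters()])
            # Sum the EWC penalty over all previously consolidated tasks
            return sum(ewc_loss(f, theta, w, 0.1) for f, w in anchors)

        # Hypothetical helper: runs the same training loop as above,
        # adding penalty() to the task loss on every batch
        train_one_task(model, task_loader, extra_loss=penalty)

        # Consolidate this task: store its Fisher information and weights
        f = torch.as_tensor(calculate_fisher(model, task_loader, device),
                            dtype=torch.float32, device=device)
        w = torch.cat([p.view(-1) for p in model.parameters()]).detach().clone()
        anchors.append((f, w))

The EWC strength λ (0.1 in the script) is a stability/plasticity trade-off: larger values preserve the old task better but slow adaptation to new data, so it usually needs tuning per problem.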
