import torch
import torch.nn as nn
class RMSELoss(nn.Module):
    """Root-mean-square error between two tensors.

    Computes ``sqrt(mean((x - z) ** 2) + eps)`` over every element of the
    difference and returns the result as a scalar (0-dim) tensor.

    Args:
        eps: Non-negative constant added to the mean squared error before
            the square root is taken. The default ``0.0`` yields the exact
            RMSE. Because the derivative of ``sqrt`` is unbounded at zero,
            the backward pass produces NaN gradients when the two inputs
            are identical; pass a small positive value (e.g. ``1e-8``) to
            keep the gradients finite in that case.

    Raises:
        ValueError: If ``eps`` is negative.
    """

    def __init__(self, eps: float = 0.0) -> None:
        super().__init__()
        if eps < 0:
            # A negative offset could push the radicand below zero -> NaN.
            raise ValueError(f"eps must be non-negative, got {eps}")
        self.eps = eps

    def forward(self, x: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """Return the RMSE between ``x`` and ``z``.

        Args:
            x: First tensor.
            z: Second tensor; it is subtracted from ``x`` element-wise, so
                the two shapes must be compatible for that subtraction.

        Returns:
            A 0-dim tensor holding ``sqrt(mean((x - z) ** 2) + eps)``.
        """
        mse = torch.mean((x - z) ** 2)
        if self.eps:
            # Only touch the value when stabilisation was requested, so the
            # default configuration returns the plain RMSE unchanged.
            mse = mse + self.eps
        return torch.sqrt(mse)