Source code for titli.utils.loss

import torch
import torch.nn as nn

class RMSELoss(nn.Module):
    """Root-mean-square error between two tensors.

    Computes ``sqrt(mean((x - z) ** 2) + eps)``, where the mean runs over
    every element of the (broadcast) difference, so the result is a scalar
    tensor.

    Args:
        eps: Non-negative constant added under the square root. The default
            of ``0.0`` yields the plain RMSE. A small positive value (e.g.
            ``1e-8``) keeps the gradient finite when ``x`` equals ``z``:
            the derivative of ``sqrt`` at zero is infinite, which otherwise
            produces NaN gradients during backpropagation.

    Raises:
        ValueError: If ``eps`` is negative.
    """

    def __init__(self, eps: float = 0.0) -> None:
        super().__init__()
        if eps < 0:
            raise ValueError(f"eps must be non-negative, got {eps}")
        self.eps = eps

    def forward(self, x: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """Return the RMSE between ``x`` and ``z`` as a scalar tensor.

        Args:
            x: First tensor.
            z: Second tensor; ``x - z`` must be a valid (broadcastable)
                subtraction.

        Returns:
            Zero-dimensional tensor holding the root-mean-square error.
        """
        squared_difference = (x - z) ** 2
        mean = torch.mean(squared_difference)
        if self.eps:
            # Shift away from zero so sqrt has a finite derivative there.
            mean = mean + self.eps
        return torch.sqrt(mean)