# Source code for espaloma.nn.readout.base_readout

# =============================================================================
# IMPORTS
# =============================================================================
import abc

import torch


# =============================================================================
# BASE CLASSES
# =============================================================================
class BaseReadout(abc.ABC, torch.nn.Module):
    """Base class for readout functions.

    Combines ``abc.ABC`` with ``torch.nn.Module``, so concrete readouts are
    ordinary PyTorch modules. Because ``forward`` is an abstract method, this
    class (and any subclass that does not override ``forward``) cannot be
    instantiated.
    """

    def __init__(self):
        # ``abc.ABC`` defines no ``__init__``, so this resolves to
        # ``torch.nn.Module.__init__`` via the MRO.
        super().__init__()

    @abc.abstractmethod
    def forward(self, g, x=None, *args, **kwargs):
        """Compute the readout.

        Abstract: subclasses must override this method.

        Parameters
        ----------
        g :
            Input object the readout operates on. NOTE(review): presumably a
            graph; the concrete type is not fixed here, so confirm against
            subclasses.
        x : optional
            Optional input features. Defaults to ``None``.
        *args, **kwargs :
            Extra implementation-specific arguments.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        raise NotImplementedError

    def _forward(self, g, x, *args, **kwargs):
        """Private hook that subclasses may override. Not implemented here.

        Unlike ``forward``, ``x`` is a required positional argument.

        Raises
        ------
        NotImplementedError
            Always, unless overridden.
        """
        raise NotImplementedError