# =============================================================================
# IMPORTS
# =============================================================================
import abc
import torch
# =============================================================================
# BASE CLASSES
# =============================================================================
class BaseReadout(abc.ABC, torch.nn.Module):
    """Base class for readout function.

    Combines ``abc.ABC`` with ``torch.nn.Module``: because ``forward`` is
    declared abstract, this class cannot be instantiated directly and every
    concrete subclass has to provide its own ``forward``.
    """

    def __init__(self):
        """Initialize the underlying ``torch.nn.Module`` state."""
        super().__init__()

    @abc.abstractmethod
    def forward(self, g, x=None, *args, **kwargs):
        """Apply the readout; must be overridden by subclasses.

        Parameters
        ----------
        g
            First positional input.
            NOTE(review): presumably a graph object -- confirm with callers.
        x : optional
            Second input, defaults to ``None``.
            NOTE(review): presumably features associated with ``g`` -- confirm.
        *args, **kwargs
            Extra arguments accepted for subclass-specific use.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        raise NotImplementedError

    def _forward(self, g, x, *args, **kwargs):
        """Non-public counterpart of ``forward`` where ``x`` is required.

        Not abstract, so subclasses may leave it untouched; the base
        implementation only raises.
        NOTE(review): looks like an optional hook for subclasses whose
        ``forward`` delegates here -- confirm against concrete readouts.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        raise NotImplementedError