Source code for espaloma.nn.readout.node_typing

# =============================================================================
# IMPORTS
# =============================================================================
import torch

from espaloma.nn.readout.base_readout import BaseReadout


# =============================================================================
# MODULE CLASSES
# =============================================================================
[docs]class NodeTyping(BaseReadout): """Simple typing on homograph."""
[docs] def __init__(self, in_features, n_classes=100): super(NodeTyping, self).__init__() self.c = torch.nn.Linear(in_features, n_classes)
[docs] def forward(self, g): g.apply_nodes( ntype="n1", func=lambda node: {"nn_typing": self.c(node.data["h"])}, ) return g