Source code for espaloma.nn.layers.dgl_legacy

""" Legacy models from DGL.

"""

# =============================================================================
# IMPORTS
# =============================================================================
import torch

# =============================================================================
# CONSTANT
# =============================================================================
DEFAULT_MODEL_KWARGS = {
    "SAGEConv": {"aggregator_type": "mean"},
    "GATConv": {"num_heads": 4},
    "TAGConv": {"k": 2},
}
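# These defaults are consulted by the GN class below only when the caller
# passes an empty ``kwargs`` dict; a non-empty dict overrides them entirely.
# A minimal sketch (the layer sizes and head count are illustrative
# assumptions, not values from this module):
#
# >>> layer = GN(16, 32, model_name="GATConv", kwargs={"num_heads": 8})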


# =============================================================================
# MODULE CLASSES
# =============================================================================
class GN(torch.nn.Module):
    """Wrap a DGL graph-convolution layer under a uniform interface."""

    def __init__(
        self,
        in_features,
        out_features,
        model_name="GraphConv",
        kwargs={},
    ):
        super(GN, self).__init__()
        from dgl.nn import pytorch as dgl_pytorch

        # fall back to the default keyword arguments for this layer type
        # when the caller does not supply any
        if kwargs == {}:
            if model_name in DEFAULT_MODEL_KWARGS:
                kwargs = DEFAULT_MODEL_KWARGS[model_name]

        self.gn = getattr(dgl_pytorch.conv, model_name)(
            in_features, out_features, **kwargs
        )

        # register these properties here for downstream handling
        self.in_features = in_features
        self.out_features = out_features

    def forward(self, g, x):
        return self.gn(g, x)
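# A minimal usage sketch for ``GN`` (assumes DGL and PyTorch are installed;
# the toy graph, feature sizes, and variable names are illustrative
# assumptions, not part of this module):
#
# >>> import dgl
# >>> g = dgl.graph(([0, 1, 1, 2], [1, 0, 2, 1]))  # 3 nodes, 4 directed edges
# >>> x = torch.randn(3, 16)                       # one 16-dim vector per node
# >>> layer = GN(16, 32, model_name="SAGEConv")    # default "mean" aggregator
# >>> layer(g, x).shape
# torch.Size([3, 32])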
# =============================================================================
# MODULE FUNCTIONS
# =============================================================================
def gn(model_name="GraphConv", kwargs={}):
    """Return a constructor mapping (in_features, out_features) to a layer."""
    from dgl.nn import pytorch as dgl_pytorch

    if model_name == "GINConv":
        # GINConv does not take (in_features, out_features) directly;
        # it wraps an apply function instead
        return lambda in_features, out_features: dgl_pytorch.conv.GINConv(
            apply_func=torch.nn.Linear(in_features, out_features),
            aggregator_type="sum",
        )

    else:
        return lambda in_features, out_features: GN(
            in_features=in_features,
            out_features=out_features,
            model_name=model_name,
            kwargs=kwargs,
        )
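# A hedged sketch of the factory pattern: ``gn`` defers the layer sizes so a
# network builder can fill them in later (the names below are illustrative
# assumptions):
#
# >>> make_layer = gn(model_name="GATConv")
# >>> layer = make_layer(16, 32)  # GN wrapping GATConv with num_heads=4
#
# Note that multi-head layers such as GATConv produce per-head outputs of
# shape (n_nodes, num_heads, out_features) rather than (n_nodes, out_features).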