Source code for espaloma.nn.readout.janossy

# =============================================================================
# IMPORTS
# =============================================================================
import torch

import espaloma as esp


# =============================================================================
# MODULE CLASSES
# =============================================================================
class JanossyPooling(torch.nn.Module):
    """Janossy pooling (arXiv:1811.01900) to average node representations
    for higher-order nodes.
    """
    def __init__(
        self,
        config,
        in_features,
        out_features={
            1: ["sigma", "epsilon", "q"],
            2: ["k", "eq"],
            3: ["k", "eq"],
            4: ["k", "eq"],
        },
        out_features_dimensions=-1,
        pool=torch.add,
    ):
        super(JanossyPooling, self).__init__()

        # if users specify out features as lists,
        # assume each dimension to be one
        for level in out_features.keys():
            if isinstance(out_features[level], list):
                out_features[level] = dict(
                    zip(out_features[level], [1 for _ in out_features[level]])
                )

        # bookkeeping
        self.out_features = out_features
        self.levels = [key for key in out_features.keys() if key != 1]
        self.pool = pool

        # the width of the last hidden layer in the config
        mid_features = [x for x in config if isinstance(x, int)][-1]

        # set up networks
        for level in self.levels:
            # set up individual sequential networks
            setattr(
                self,
                "sequential_%s" % level,
                esp.nn.sequential._Sequential(
                    in_features=in_features * level,
                    config=config,
                    layer=torch.nn.Linear,
                ),
            )

            for feature, dimension in self.out_features[level].items():
                setattr(
                    self,
                    "f_out_%s_to_%s" % (level, feature),
                    torch.nn.Linear(
                        mid_features,
                        dimension,
                    ),
                )

        if 1 not in self.out_features:
            return

        # atom level
        self.sequential_1 = esp.nn.sequential._Sequential(
            in_features=in_features, config=config, layer=torch.nn.Linear
        )

        for feature, dimension in self.out_features[1].items():
            setattr(
                self,
                "f_out_1_to_%s" % feature,
                torch.nn.Linear(
                    mid_features,
                    dimension,
                ),
            )
    def forward(self, g):
        """Forward pass.

        Parameters
        ----------
        g : dgl.DGLHeteroGraph,
            input graph.
        """
        import dgl

        # copy atom (n1) representations onto higher-order nodes,
        # one slot per atom position
        g.multi_update_all(
            {
                "n1_as_%s_in_n%s" % (relationship_idx, big_idx): (
                    dgl.function.copy_src("h", "m%s" % relationship_idx),
                    dgl.function.mean(
                        "m%s" % relationship_idx, "h%s" % relationship_idx
                    ),
                )
                for big_idx in self.levels
                for relationship_idx in range(big_idx)
            },
            cross_reducer="sum",
        )

        # pool over the forward and reverse atom orderings
        for big_idx in self.levels:
            if g.number_of_nodes("n%s" % big_idx) == 0:
                continue

            g.apply_nodes(
                func=lambda nodes: {
                    feature: getattr(
                        self, "f_out_%s_to_%s" % (big_idx, feature)
                    )(
                        self.pool(
                            getattr(self, "sequential_%s" % big_idx)(
                                None,
                                torch.cat(
                                    [
                                        nodes.data["h%s" % relationship_idx]
                                        for relationship_idx in range(big_idx)
                                    ],
                                    dim=1,
                                ),
                            ),
                            getattr(self, "sequential_%s" % big_idx)(
                                None,
                                torch.cat(
                                    [
                                        nodes.data["h%s" % relationship_idx]
                                        for relationship_idx in range(
                                            big_idx - 1, -1, -1
                                        )
                                    ],
                                    dim=1,
                                ),
                            ),
                        ),
                    )
                    for feature in self.out_features[big_idx].keys()
                },
                ntype="n%s" % big_idx,
            )

        if 1 not in self.out_features:
            return g

        # atom level
        g.apply_nodes(
            func=lambda nodes: {
                feature: getattr(self, "f_out_1_to_%s" % feature)(
                    self.sequential_1(g=None, x=nodes.data["h"])
                )
                for feature in self.out_features[1].keys()
            },
            ntype="n1",
        )

        return g
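
# --- Illustrative usage sketch ----------------------------------------------
# A minimal sketch of how `JanossyPooling` is composed downstream of a graph
# representation stage; the SAGEConv layer choice and the config values are
# assumptions for illustration, not prescribed by this module.
def _example_janossy_pooling_net():
    import torch

    import espaloma as esp

    # atom-level (n1) representation via message passing
    representation = esp.nn.Sequential(
        layer=esp.nn.layers.dgl_legacy.gn("SAGEConv"),  # assumed layer choice
        config=[32, "relu", 32, "relu"],
    )

    # Janossy readout: pools atom representations into n2/n3/n4 nodes and
    # predicts per-term parameters ("k", "eq", ...) via the default out_features
    readout = esp.nn.readout.janossy.JanossyPooling(
        in_features=32,  # must match the representation's output width
        config=[32, "relu"],
    )

    # the two stages chain on the heterograph: net(g) annotates g with parameters
    return torch.nn.Sequential(representation, readout)
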
class JanossyPoolingImproper(torch.nn.Module):
    """Janossy pooling (arXiv:1811.01900) to average node representations
    for improper torsions.
    """
    def __init__(
        self,
        config,
        in_features,
        out_features={
            "k": 6,
        },
        out_features_dimensions=-1,
    ):
        super(JanossyPoolingImproper, self).__init__()

        # bookkeeping
        self.out_features = out_features
        self.levels = ["n4_improper"]

        # the width of the last hidden layer in the config
        mid_features = [x for x in config if isinstance(x, int)][-1]

        # set up networks
        for level in self.levels:
            # set up individual sequential networks
            setattr(
                self,
                "sequential_%s" % level,
                esp.nn.sequential._Sequential(
                    in_features=4 * in_features,
                    config=config,
                    layer=torch.nn.Linear,
                ),
            )

            for feature, dimension in self.out_features.items():
                setattr(
                    self,
                    "f_out_%s_to_%s" % (level, feature),
                    torch.nn.Linear(
                        mid_features,
                        dimension,
                    ),
                )
    def forward(self, g):
        """Forward pass.

        Parameters
        ----------
        g : dgl.DGLHeteroGraph,
            input graph.
        """
        import dgl

        # copy atom (n1) representations onto improper-torsion nodes
        g.multi_update_all(
            {
                "n1_as_%s_in_%s" % (relationship_idx, big_idx): (
                    dgl.function.copy_src("h", "m%s" % relationship_idx),
                    dgl.function.mean(
                        "m%s" % relationship_idx, "h%s" % relationship_idx
                    ),
                )
                for big_idx in self.levels
                for relationship_idx in range(4)
            },
            cross_reducer="sum",
        )

        if g.number_of_nodes("n4_improper") == 0:
            return g

        # pool: sum over the three cyclic permutations of ("h0", "h2", "h3"),
        # with "h1" as the central atom of the improper, following the
        # SMIRNOFF trefoil convention [(0, 1, 2, 3), (2, 1, 3, 0), (3, 1, 0, 2)]:
        # https://github.com/openforcefield/openff-toolkit/blob/166c9864de3455244bd80b2c24656bd7dda3ae2d/openff/toolkit/typing/engines/smirnoff/parameters.py#L3326-L3360
        permuts = [(0, 1, 2, 3), (2, 1, 3, 0), (3, 1, 0, 2)]

        def stack_permuts(nodes, p):
            # concatenate the atom representations in permuted order
            return torch.cat([nodes.data[f"h{i}"] for i in p], dim=1)

        for big_idx in self.levels:
            inner_net = getattr(self, f"sequential_{big_idx}")
            g.apply_nodes(
                func=lambda nodes: {
                    feature: getattr(self, f"f_out_{big_idx}_to_{feature}")(
                        torch.sum(
                            torch.stack(
                                [
                                    inner_net(g=None, x=stack_permuts(nodes, p))
                                    for p in permuts
                                ],
                                dim=0,
                            ),
                            dim=0,
                        )
                    )
                    for feature in self.out_features.keys()
                },
                ntype=big_idx,
            )

        return g
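
# --- Illustrative sketch ------------------------------------------------------
# A minimal, self-contained check of why summing over the trefoil permutations
# makes the improper readout invariant to them: the permutation set is closed
# under composition, so relabeling the atoms by any of its members permutes the
# summands without changing the sum.  `phi` is a hypothetical stand-in for the
# inner network (`sequential_n4_improper` above).
def _check_trefoil_invariance():
    import torch

    # four per-atom representations of width 8, batch dimension 1
    h = {i: torch.randn(1, 8) for i in range(4)}
    permuts = [(0, 1, 2, 3), (2, 1, 3, 0), (3, 1, 0, 2)]

    phi = torch.nn.Linear(4 * 8, 1)  # stand-in for the Janossy inner network

    def pooled(reps):
        # sum phi over the three trefoil orderings of the representations
        return sum(phi(torch.cat([reps[i] for i in p], dim=1)) for p in permuts)

    # relabel the atoms by the trefoil permutation (2, 1, 3, 0);
    # the pooled output is unchanged
    relabeled = {new: h[old] for new, old in enumerate(permuts[1])}
    assert torch.allclose(pooled(h), pooled(relabeled), atol=1e-5)
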
class JanossyPoolingWithSmirnoffImproper(torch.nn.Module):
    """Janossy pooling (arXiv:1811.01900) to average node representations
    for improper torsions whose central atom is listed first
    (SMIRNOFF-style improper definition).
    """
    def __init__(
        self,
        config,
        in_features,
        out_features={
            "k": 6,
        },
        out_features_dimensions=-1,
    ):
        super(JanossyPoolingWithSmirnoffImproper, self).__init__()

        # bookkeeping
        self.out_features = out_features
        self.levels = ["n4_improper"]

        # the width of the last hidden layer in the config
        mid_features = [x for x in config if isinstance(x, int)][-1]

        # set up networks
        for level in self.levels:
            # set up individual sequential networks
            setattr(
                self,
                "sequential_%s" % level,
                esp.nn.sequential._Sequential(
                    in_features=4 * in_features,
                    config=config,
                    layer=torch.nn.Linear,
                ),
            )

            for feature, dimension in self.out_features.items():
                setattr(
                    self,
                    "f_out_%s_to_%s" % (level, feature),
                    torch.nn.Linear(
                        mid_features,
                        dimension,
                    ),
                )
    def forward(self, g):
        """Forward pass.

        Parameters
        ----------
        g : dgl.DGLHeteroGraph,
            input graph.
        """
        import dgl

        # copy atom (n1) representations onto improper-torsion nodes
        g.multi_update_all(
            {
                "n1_as_%s_in_%s" % (relationship_idx, big_idx): (
                    dgl.function.copy_src("h", "m%s" % relationship_idx),
                    dgl.function.mean(
                        "m%s" % relationship_idx, "h%s" % relationship_idx
                    ),
                )
                for big_idx in self.levels
                for relationship_idx in range(4)
            },
            cross_reducer="sum",
        )

        if g.number_of_nodes("n4_improper") == 0:
            return g

        # pool: sum over the three cyclic permutations of ("h1", "h2", "h3"),
        # with "h0" as the central atom of the improper; cf. the SMIRNOFF
        # trefoil convention:
        # https://github.com/openforcefield/openff-toolkit/blob/166c9864de3455244bd80b2c24656bd7dda3ae2d/openff/toolkit/typing/engines/smirnoff/parameters.py#L3326-L3360
        permuts = [(0, 1, 2, 3), (0, 2, 3, 1), (0, 3, 1, 2)]

        def stack_permuts(nodes, p):
            # concatenate the atom representations in permuted order
            return torch.cat([nodes.data[f"h{i}"] for i in p], dim=1)

        for big_idx in self.levels:
            inner_net = getattr(self, f"sequential_{big_idx}")
            g.apply_nodes(
                func=lambda nodes: {
                    feature: getattr(self, f"f_out_{big_idx}_to_{feature}")(
                        torch.sum(
                            torch.stack(
                                [
                                    inner_net(g=None, x=stack_permuts(nodes, p))
                                    for p in permuts
                                ],
                                dim=0,
                            ),
                            dim=0,
                        )
                    )
                    for feature in self.out_features.keys()
                },
                ntype=big_idx,
            )

        return g
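
# --- Illustrative sketch ------------------------------------------------------
# The only difference from `JanossyPoolingImproper` above is which slot holds
# the improper's central atom (0 here, 1 there).  A hypothetical helper showing
# how both permutation sets arise by cyclically rotating the three outer atoms
# around a fixed central position:
def _cyclic_improper_permuts(central):
    outer = [i for i in range(4) if i != central]  # the three outer atoms
    permuts = []
    for shift in range(3):
        rotated = outer[shift:] + outer[:shift]  # cyclic rotation of the outers
        rotated.insert(central, central)  # put the central atom back in its slot
        permuts.append(tuple(rotated))
    return permuts

# _cyclic_improper_permuts(1) == [(0, 1, 2, 3), (2, 1, 3, 0), (3, 1, 0, 2)]
# _cyclic_improper_permuts(0) == [(0, 1, 2, 3), (0, 2, 3, 1), (0, 3, 1, 2)]
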
class JanossyPoolingNonbonded(torch.nn.Module):
    """Janossy pooling (arXiv:1811.01900) to average node representations
    for nonbonded and 1-4 interaction pairs.
    """
    def __init__(
        self,
        config,
        in_features,
        out_features={"sigma": 1, "epsilon": 1},
        out_features_dimensions=-1,
    ):
        super(JanossyPoolingNonbonded, self).__init__()

        # bookkeeping
        self.out_features = out_features
        self.levels = ["onefour", "nonbonded"]

        # the width of the last hidden layer in the config
        mid_features = [x for x in config if isinstance(x, int)][-1]

        # set up networks
        for level in self.levels:
            # set up individual sequential networks
            setattr(
                self,
                "sequential_%s" % level,
                esp.nn.sequential._Sequential(
                    in_features=2 * in_features,
                    config=config,
                    layer=torch.nn.Linear,
                ),
            )

            for feature, dimension in self.out_features.items():
                setattr(
                    self,
                    "f_out_%s_to_%s" % (level, feature),
                    torch.nn.Linear(
                        mid_features,
                        dimension,
                    ),
                )
    def forward(self, g):
        """Forward pass.

        Parameters
        ----------
        g : dgl.DGLHeteroGraph,
            input graph.
        """
        import dgl

        # copy atom (n1) representations onto pair nodes
        g.multi_update_all(
            {
                "n1_as_%s_in_%s" % (relationship_idx, big_idx): (
                    dgl.function.copy_src("h", "m%s" % relationship_idx),
                    dgl.function.mean(
                        "m%s" % relationship_idx, "h%s" % relationship_idx
                    ),
                )
                for big_idx in self.levels
                for relationship_idx in range(2)
            },
            cross_reducer="sum",
        )

        # pool: symmetrize over the two atom orderings of each pair
        for big_idx in self.levels:
            g.apply_nodes(
                func=lambda nodes: {
                    feature: getattr(
                        self, "f_out_%s_to_%s" % (big_idx, feature)
                    )(
                        torch.sum(
                            torch.stack(
                                [
                                    getattr(self, "sequential_%s" % big_idx)(
                                        g=None,
                                        x=torch.cat(
                                            [
                                                nodes.data["h0"],
                                                nodes.data["h1"],
                                            ],
                                            dim=1,
                                        ),
                                    ),
                                    getattr(self, "sequential_%s" % big_idx)(
                                        g=None,
                                        x=torch.cat(
                                            [
                                                nodes.data["h1"],
                                                nodes.data["h0"],
                                            ],
                                            dim=1,
                                        ),
                                    ),
                                ],
                                dim=0,
                            ),
                            dim=0,
                        )
                    )
                    for feature in self.out_features.keys()
                },
                ntype=big_idx,
            )

        return g
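
# --- Illustrative sketch ------------------------------------------------------
# A minimal, self-contained check of the pair symmetrization used above:
# summing the inner network over both atom orderings makes sigma/epsilon
# predictions independent of which atom of the pair is listed first.  `phi` is
# a hypothetical stand-in for `sequential_onefour` / `sequential_nonbonded`.
def _check_pair_symmetry():
    import torch

    h0, h1 = torch.randn(1, 8), torch.randn(1, 8)
    phi = torch.nn.Linear(2 * 8, 1)  # stand-in for the Janossy inner network

    def pooled(a, b):
        # phi(a || b) + phi(b || a) is symmetric under swapping a and b
        return phi(torch.cat([a, b], dim=1)) + phi(torch.cat([b, a], dim=1))

    assert torch.allclose(pooled(h0, h1), pooled(h1, h0))
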
class ExpCoefficients(torch.nn.Module):
    """Exponentiate log-coefficients on bond (n2) and angle (n3) nodes to
    guarantee positive coefficients.
    """
    def forward(self, g):
        g.nodes["n2"].data["coefficients"] = (
            g.nodes["n2"].data["log_coefficients"].exp()
        )

        g.nodes["n3"].data["coefficients"] = (
            g.nodes["n3"].data["log_coefficients"].exp()
        )

        return g
class LinearMixtureToOriginal(torch.nn.Module):
    """Convert linear-mixture coefficients on bond (n2) and angle (n3) nodes
    to canonical force constants ("k") and equilibrium values ("eq").
    """
    def forward(self, g):
        import math

        (
            g.nodes["n2"].data["k"],
            g.nodes["n2"].data["eq"],
        ) = esp.mm.functional.linear_mixture_to_original(
            g.nodes["n2"].data["coefficients"][:, 0][:, None],
            g.nodes["n2"].data["coefficients"][:, 1][:, None],
            1.5,
            6.0,
        )

        (
            g.nodes["n3"].data["k"],
            g.nodes["n3"].data["eq"],
        ) = esp.mm.functional.linear_mixture_to_original(
            g.nodes["n3"].data["coefficients"][:, 0][:, None],
            g.nodes["n3"].data["coefficients"][:, 1][:, None],
            0.0,
            math.pi,
        )

        g.nodes["n3"].data.pop("coefficients")
        g.nodes["n2"].data.pop("coefficients")

        return g
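
# --- Illustrative sketch ------------------------------------------------------
# The conversion implied by the linear-mixture parameterization, stated as a
# sketch rather than the exact espaloma implementation.  Anchoring a harmonic
# energy at two basis points b1, b2 (1.5 and 6.0 for bonds, 0 and pi for
# angles above),
#
#     u(x) = k1 * (x - b1) ** 2 + k2 * (x - b2) ** 2,
#
# and completing the square recovers the canonical parameters
#
#     k = k1 + k2,    eq = (k1 * b1 + k2 * b2) / (k1 + k2).
#
# `esp.mm.functional.linear_mixture_to_original` is assumed to implement this
# mapping, possibly with additional numerical safeguards.
def _linear_mixture_to_original_sketch(k1, k2, b1, b2):
    k = k1 + k2  # force constants add
    eq = (k1 * b1 + k2 * b2) / k  # equilibrium is the k-weighted average
    return k, eq
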