import dgl
import numpy as np
import torch
from .offmol_indices import improper_torsion_indices
from ..graph import Graph
def regenerate_impropers(g: Graph, improper_def='smirnoff'):
    """
    Regenerate the improper nodes of ``g`` using the given permutation scheme.

    Method to regenerate the improper nodes according to the specified
    method of permuting the impropers. Modifies the esp.Graph's heterograph
    in place (``g.heterograph`` is replaced) and returns the graph.

    NOTE: This will clear the data on all n4_improper nodes, including
    previously generated improper from JanossyPoolingImproper. This also
    holds when no impropers are generated for ``improper_def``: the old
    improper nodes are still removed.

    Parameters
    ----------
    g : Graph
        Graph whose ``heterograph`` is rebuilt. ``g.mol`` is passed to
        ``improper_torsion_indices`` to generate the new permutations.
    improper_def : str, default='smirnoff'
        Improper permutation scheme, forwarded unchanged to
        ``improper_torsion_indices``.

    Returns
    -------
    Graph
        The same ``g`` object, with ``g.heterograph`` replaced.
    """
    # First get rid of the old improper nodes and their edges.
    hg = g.heterograph
    hg = dgl.remove_nodes(hg, hg.nodes('n4_improper'), 'n4_improper')

    # Generate new improper torsion permutations. The code below indexes
    # this as a 2-D array: one row per improper, columns 0-3 are n1 ids.
    idxs = improper_torsion_indices(g.mol, improper_def)
    if len(idxs) == 0:
        # Commit the cleared graph before returning; otherwise the old
        # impropers would silently survive on ``g``.
        g.heterograph = hg
        return g

    # Add new nodes of type n4_improper (one for each permutation).
    n_impropers = idxs.shape[0]
    hg = dgl.add_nodes(hg, n_impropers, ntype='n4_improper')

    permut_ids = np.arange(n_impropers)
    # Every improper node is connected to graph node 0.
    graph_ids = np.zeros_like(permut_ids)

    # Edges in both directions between improper nodes and n1 nodes, with a
    # distinct edge type for each position i (0-3) within the improper.
    for i in range(4):
        n1_ids = idxs[:, i]
        hg = dgl.add_edges(
            hg, permut_ids, n1_ids,
            etype=('n4_improper', f'n4_improper_has_{i}_n1', 'n1'),
        )
        hg = dgl.add_edges(
            hg, n1_ids, permut_ids,
            etype=('n1', f'n1_as_{i}_in_n4_improper', 'n4_improper'),
        )

    # Edges in both directions between improper nodes and the graph node
    # (for global pooling).
    hg = dgl.add_edges(
        hg, permut_ids, graph_ids,
        etype=('n4_improper', 'n4_improper_in_g', 'g'),
    )
    hg = dgl.add_edges(
        hg, graph_ids, permut_ids,
        etype=('g', 'g_has_n4_improper', 'n4_improper'),
    )

    hg.nodes['n4_improper'].data['idxs'] = torch.tensor(idxs)
    g.heterograph = hg
    return g