# Source code for espaloma.graphs.utils.read_heterogeneous_graph

"""Build heterogeneous graphs (atom, bond, angle, torsion, improper, nonbonded and 1-4 pair nodes) from homogeneous ones.

"""
# =============================================================================
# IMPORTS
# =============================================================================
import numpy as np
import torch
from espaloma.graphs.utils import offmol_indices
from openff.toolkit.topology import Molecule
from typing import Dict

# =============================================================================
# UTILITY FUNCTIONS
# =============================================================================


def duplicate_index_ordering(indices: np.ndarray) -> np.ndarray:
    """Append the reversed ordering of every index tuple.

    Every row ``(a, b, c, d)`` of ``indices`` is kept as-is, and after all
    original rows a mirrored copy ``(d, c, b, a)`` of each row is appended.

    TODO: is there a way to avoid this duplication?

    Parameters
    ----------
    indices : np.ndarray
        Array whose last axis holds the index tuples.

    Returns
    -------
    np.ndarray
        The original rows followed by the mirrored rows, joined along axis 0.

    >>> indices = np.array([[0, 1, 2, 3], [1, 2, 3, 4]])
    >>> duplicate_index_ordering(indices)
    array([[0, 1, 2, 3],
           [1, 2, 3, 4],
           [3, 2, 1, 0],
           [4, 3, 2, 1]])
    """
    # reverse the order of atoms inside each tuple (last axis only)
    mirrored = np.flip(indices, -1)
    # stack originals first, mirrored copies second (default axis is 0)
    return np.concatenate((indices, mirrored))
def relationship_indices_from_offmol(
    offmol: Molecule,
) -> Dict[str, torch.Tensor]:
    """Construct a dictionary that maps node names (like "n2") to torch tensors of indices.

    Parameters
    ----------
    offmol : openff.toolkit.topology.Molecule
        Molecule whose atom, bond, angle, proper-torsion and improper-torsion
        indices are enumerated through ``offmol_indices``.

    Returns
    -------
    Dict[str, torch.Tensor]
        Keys ``"n1"``, ``"n2"``, ``"n3"``, ``"n4"`` and ``"n4_improper"``.
        If a bond/angle/torsion set is empty it is returned as an int64
        tensor of shape ``(0, k)`` (``k`` = number of atoms in the term), so
        downstream column slicing still works and the tensor is not float.

    Notes
    -----
    * introduces 2x redundant indices (including (d,c,b,a) for every (a,b,c,d))
      for "n2", "n3" and "n4", for compatibility with later processing
    """
    idxs = {
        "n1": offmol_indices.atom_indices(offmol),
        "n2": offmol_indices.bond_indices(offmol),
        "n3": offmol_indices.angle_indices(offmol),
        "n4": offmol_indices.proper_torsion_indices(offmol),
        "n4_improper": offmol_indices.improper_torsion_indices(offmol),
    }

    # Replace empty index sets with well-formed (0, k) integer arrays.
    # A bare ``np.empty((0, k))`` is float64, which would turn the index
    # tensor for e.g. a torsion-free molecule into a float tensor while
    # populated terms are (presumably) integer -- a dtype mismatch when
    # graphs are later combined.
    term_widths = {"n2": 2, "n3": 3, "n4": 4, "n4_improper": 4}
    for key, width in term_widths.items():
        if len(idxs[key]) == 0:
            idxs[key] = np.empty((0, width), dtype=np.int64)

    # TODO: enumerate indices for coupling-term nodes also
    # TODO: big refactor of term names from "n4" to "proper_torsion", "improper_torsion", "angle_angle_coupling", etc.
    # TODO (discuss with YW) : I think "n1" and "n4_improper" shouldn't be 2x redundant in current scheme
    # (also, unclear why we need "n2", "n3", "n4" to be 2x redundant, but that's something to consider changing later)
    for key in ["n2", "n3", "n4"]:
        idxs[key] = duplicate_index_ordering(idxs[key])

    # make them all torch.Tensors
    return {key: torch.from_numpy(value) for key, value in idxs.items()}
def from_homogeneous_and_mol(g, offmol):
    r"""Build heterogeneous graph from homogeneous ones.

    Note
    ----
    For now we name single node, two-, three, and four-, hypernodes as
    `n1`, `n2`, `n3`, and `n4`. These correspond to atom, bond, angle,
    and torsion nodes in chemical graphs. In addition, `n4_improper` holds
    improper torsions, `nonbonded` and `onefour` hold atom pairs derived
    from the adjacency matrix, and a node type `g` represents the whole graph.

    Parameters
    ----------
    g : `espaloma.HomogeneousGraph` object
        the homogeneous graph to be translated. Only
        ``g.adjacency_matrix()``, ``g.number_of_nodes()`` and
        ``g.ndata["h0"]`` are used.
    offmol : `openff.toolkit.topology.Molecule`
        the molecule whose bond, angle, proper- and improper-torsion indices
        define the hypernodes (obtained via
        ``relationship_indices_from_offmol``).

    Returns
    -------
    hg : DGL heterograph
        the resulting heterogeneous graph, as returned by ``dgl.heterograph``.
        Node types n1, n2, n3, n4, n4_improper, onefour and nonbonded carry
        an ``"idxs"`` feature with their atom indices; ``"n1"`` nodes also
        carry ``"h0"`` copied from ``g``.
    """
    # Index arrays (numpy) for every term. "n1" is re-enumerated from the
    # homogeneous graph's node ids 0..N-1.
    idxs = {
        key: value.numpy()
        for key, value in relationship_indices_from_offmol(offmol).items()
    }
    idxs["n1"] = np.arange(g.number_of_nodes())[:, None]

    # (src_type, relation, dst_type) -> (n_edges, 2) array of (src, dst) ids
    relations = {}

    # =========================
    # neighboring relationships
    # =========================
    # NOTE: here we only define the neighboring relationship on atom level
    relations[("n1", "n1_neighbors_n1", "n1")] = idxs["n2"]

    # map each index tuple to its row ("ordering") within its term
    idxs_to_ordering = {
        term: {tuple(row): ordering for ordering, row in enumerate(idxs[term])}
        for term in ["n1", "n2", "n3", "n4", "n4_improper"]
    }

    # ===============================================
    # relationships between nodes of different levels
    # ===============================================
    # For each n{big} node, the n{small} sub-tuple starting at position
    # pos_idx gives one 'as_in' edge (small -> big) and one 'has' edge
    # (big -> small).
    # TODO: we'll test later to see if this adds too much overhead
    for small_idx in range(1, 5):
        small = f"n{small_idx}"
        for big_idx in range(small_idx + 1, 5):
            big = f"n{big_idx}"
            big_ordering = np.arange(idxs[big].shape[0])
            for pos_idx in range(big_idx - small_idx + 1):
                # explicit int dtype keeps node ids integral even when the
                # n{big} term is empty (np.array([]) would be float64)
                small_ordering = np.array(
                    [
                        idxs_to_ordering[small][tuple(x)]
                        for x in idxs[big][:, pos_idx : pos_idx + small_idx]
                    ],
                    dtype=np.int64,
                )
                as_in = f"n{small_idx}_as_{pos_idx}_in_n{big_idx}"
                relations[(small, as_in, big)] = np.stack(
                    [small_ordering, big_ordering], axis=1
                )
                has = f"n{big_idx}_has_{pos_idx}_n{small_idx}"
                relations[(big, has, small)] = np.stack(
                    [big_ordering, small_ordering], axis=1
                )

    # ======================================
    # nonbonded terms
    # ======================================
    # dense adjacency matrix A and its powers (each computed once)
    a_ = g.adjacency_matrix().to_dense().detach().numpy()
    a2 = a_ @ a_
    a3 = a2 @ a_

    # nonbonded: pairs (i, j) with A + A^2 + A^3 == 0; for a 0/1 adjacency
    # matrix no walk of length 1, 2 or 3 connects them.
    # NOTE: everything is counted twice here, (i, j) and (j, i), assuming a
    # symmetric adjacency matrix.
    idxs["nonbonded"] = np.stack(np.where(np.equal(a_ + a2 + a3, 0.0)), axis=-1)

    # onefour: no walk of length 1 or 2, but at least one walk of length 3
    idxs["onefour"] = np.stack(
        np.where(np.equal(a_ + a2, 0.0) & np.greater(a3, 0.0)),
        axis=-1,
    )

    # membership of n1 in pair terms and in improper torsions
    for term, n_positions in (("nonbonded", 2), ("onefour", 2), ("n4_improper", 4)):
        term_ordering = np.arange(idxs[term].shape[0])
        for pos_idx in range(n_positions):
            atoms = idxs[term][:, pos_idx]
            relations[(term, f"{term}_has_{pos_idx}_n1", "n1")] = np.stack(
                [term_ordering, atoms], axis=-1
            )
            relations[("n1", f"n1_as_{pos_idx}_in_{term}", term)] = np.stack(
                [atoms, term_ordering], axis=-1
            )

    # ======================================
    # relationships between nodes and graphs
    # ======================================
    for term in ["n1", "n2", "n3", "n4", "n4_improper", "nonbonded", "onefour"]:
        n_term = len(idxs[term])
        term_ordering = np.arange(n_term)
        # integer graph ids (np.zeros defaults to float64)
        graph_ids = np.zeros(n_term, dtype=np.int64)
        relations[(term, f"{term}_in_g", "g")] = np.stack(
            [term_ordering, graph_ids], axis=1
        )
        relations[("g", f"g_has_{term}", term)] = np.stack(
            [graph_ids, term_ordering], axis=1
        )

    import dgl

    hg = dgl.heterograph({key: list(value) for key, value in relations.items()})
    hg.nodes["n1"].data["h0"] = g.ndata["h0"]

    # include indices in the nodes themselves
    for term in ["n1", "n2", "n3", "n4", "n4_improper", "onefour", "nonbonded"]:
        hg.nodes[term].data["idxs"] = torch.tensor(idxs[term])

    return hg