Source code for espaloma.mm.geometry

# =============================================================================
# IMPORTS
# =============================================================================
import torch

# =============================================================================
# UTILITY FUNCTIONS
# =============================================================================
def reduce_stack(msg, out):
    """Build a reduce function that forwards a node's mailbox unchanged.

    Parameters
    ----------
    msg : str
        Name of the message field to read from ``nodes.mailbox``.
    out : str
        Name of the output node field.

    Returns
    -------
    callable
        A function ``f(nodes)`` returning ``{out: nodes.mailbox[msg]}``,
        i.e. the received messages as delivered by the framework, with no
        aggregation applied.
    """

    def _reduce_stack(nodes, msg=msg, out=out):
        received = nodes.mailbox[msg]
        return {out: received}

    return _reduce_stack
def copy_src(src, out):
    """Build a message function that copies a source-node field onto edges.

    Parameters
    ----------
    src : str
        Name of the field read from ``edges.src``.
    out : str
        Name under which the copy is returned.

    Returns
    -------
    callable
        A function ``f(edges)`` returning ``{out: edges.src[src].clone()}``.
    """

    def _copy_src(edges, src=src, out=out):
        # ``clone`` gives the message its own storage instead of a view of
        # the source field.
        copied = edges.src[src].clone()
        return {out: copied}

    return _copy_src
# ============================================================================= # SINGLE GEOMETRY ENTITY # =============================================================================
def distance(x0, x1):
    """Distance.

    Euclidean (L2) distance between ``x0`` and ``x1``, reduced over the
    last dimension. The two tensors are subtracted directly, so they must
    be broadcast-compatible.
    """
    displacement = x0 - x1
    return displacement.norm(p=2, dim=-1)
def _angle(r0, r1):
    """Angle between vectors.

    Computes ``atan2(|r0 x r1|, r0 . r1)`` along the last dimension. The
    first argument is a norm (non-negative), so the result lies in
    ``[0, pi]`` radians.

    Parameters
    ----------
    r0, r1 : torch.Tensor
        Vectors with their 3 Cartesian components in the last dimension.

    Returns
    -------
    torch.Tensor
        Unsigned angles, with the last dimension reduced.
    """
    # ``dim=-1`` must be explicit: without it ``torch.cross`` uses the FIRST
    # dimension of size 3, which silently yields wrong angles whenever a
    # leading (batch) dimension also has size 3.
    cross_norm = torch.norm(torch.cross(r0, r1, dim=-1), p=2, dim=-1)
    dot = torch.sum(torch.mul(r0, r1), dim=-1)
    angle = torch.atan2(cross_norm, dot)
    return angle
def angle(x0, x1, x2):
    """Angle between three points.

    Returns the angle at the middle point ``x1``, computed by ``_angle``
    from the displacements ``x1 - x0`` and ``x1 - x2``.
    """
    return _angle(x1 - x0, x1 - x2)
def _dihedral(r0, r1):
    """Dihedral between normal vectors.

    Delegates to ``_angle``: the result is the unsigned angle between the
    two plane normals ``r0`` and ``r1``, so the dihedral's sign is not
    recovered here.
    """
    normal_angle = _angle(r0, r1)
    return normal_angle
def dihedral(
    x0: torch.Tensor,
    x1: torch.Tensor,
    x2: torch.Tensor,
    x3: torch.Tensor,
    noise: float = 1e-5,
) -> torch.Tensor:
    """Dihedral between four points.

    Parameters
    ----------
    x0, x1, x2, x3 : torch.Tensor
        Coordinates of the four points. All must have identical shapes, with
        the 3 Cartesian components in the last dimension.
    noise : float, optional
        Standard deviation of the Gaussian jitter added to each of the three
        displacement vectors. Defaults to ``1e-5``, the value previously
        hard-coded. Pass ``0`` for a deterministic result that does not
        draw from the global RNG.

    Returns
    -------
    torch.Tensor
        Signed dihedral angles in radians, in ``[-pi, pi]`` (the range of
        ``atan2``), with the last dimension reduced.

    Reference
    ---------
    Closely follows implementation in Yutong Zhao's timemachine:
    https://github.com/proteneer/timemachine/blob/1a0ab45e605dc1e28c44ea90f38cb0dedce5c4db/timemachine/potentials/bonded.py#L152-L199
    """
    # check input shapes
    assert x0.shape == x1.shape == x2.shape == x3.shape

    # compute displacements 0->1, 2->1, 2->3
    r01 = x1 - x0
    r21 = x1 - x2
    r23 = x3 - x2

    # NOTE(review): the jitter presumably keeps the cross products below away
    # from exactly zero for (near-)collinear atoms; confirm intent. The draw
    # order (r01, r21, r23) matches the previous implementation.
    if noise:
        r01 = r01 + torch.randn_like(x0) * noise
        r21 = r21 + torch.randn_like(x0) * noise
        r23 = r23 + torch.randn_like(x0) * noise

    # compute normal planes. ``dim=-1`` is explicit because ``torch.cross``
    # otherwise uses the first dimension of size 3, which is wrong whenever
    # a leading (batch) dimension also has size 3.
    n1 = torch.cross(r01, r21, dim=-1)
    n2 = torch.cross(r21, r23, dim=-1)

    rkj_normed = r21 / torch.norm(r21, dim=-1, keepdim=True)

    y = torch.sum(torch.mul(torch.cross(n1, n2, dim=-1), rkj_normed), dim=-1)
    x = torch.sum(torch.mul(n1, n2), dim=-1)

    # choose quadrant correctly
    theta = torch.atan2(y, x)
    return theta
# ============================================================================= # GEOMETRY IN HYPERNODES # =============================================================================
def apply_bond(nodes):
    """Bond length in nodes.

    Returns ``{"x": d}``, where ``d`` is the ``distance`` between each
    node's ``"xyz0"`` and ``"xyz1"`` fields.
    """
    data = nodes.data
    bond_length = distance(x0=data["xyz0"], x1=data["xyz1"])
    return {"x": bond_length}
def apply_angle(nodes):
    """Angle values in nodes.

    Returns a dict with:

    - ``"x"``: angle at ``xyz1`` formed with ``xyz0`` and ``xyz2``;
    - ``"x_left"``: distance ``xyz1``-``xyz0``;
    - ``"x_right"``: distance ``xyz1``-``xyz2``;
    - ``"x_between"``: distance ``xyz0``-``xyz2``.
    """
    data = nodes.data
    xyz0, xyz1, xyz2 = data["xyz0"], data["xyz1"], data["xyz2"]
    return {
        "x": angle(x0=xyz0, x1=xyz1, x2=xyz2),
        "x_left": distance(x0=xyz1, x1=xyz0),
        "x_right": distance(x0=xyz1, x1=xyz2),
        "x_between": distance(x0=xyz0, x1=xyz2),
    }
def apply_torsion(nodes):
    """Torsion dihedrals in nodes.

    Returns a dict with:

    - ``"x"``: dihedral through ``xyz0``, ``xyz1``, ``xyz2``, ``xyz3``;
    - ``"x_bond_left"`` / ``"x_bond_center"`` / ``"x_bond_right"``:
      distances ``xyz0``-``xyz1``, ``xyz1``-``xyz2``, ``xyz2``-``xyz3``;
    - ``"x_angle_left"`` / ``"x_angle_right"``: angles at ``xyz1``
      (with ``xyz0``, ``xyz2``) and at ``xyz2`` (with ``xyz1``, ``xyz3``).
    """
    data = nodes.data
    p0, p1, p2, p3 = data["xyz0"], data["xyz1"], data["xyz2"], data["xyz3"]
    # the dihedral is evaluated first, as before, so any random draws it
    # makes happen in the same order
    return {
        "x": dihedral(x0=p0, x1=p1, x2=p2, x3=p3),
        "x_bond_left": distance(x0=p0, x1=p1),
        "x_bond_center": distance(x0=p1, x1=p2),
        "x_bond_right": distance(x0=p2, x1=p3),
        "x_angle_left": angle(x0=p0, x1=p1, x2=p2),
        "x_angle_right": angle(x0=p1, x1=p2, x2=p3),
    }
# ============================================================================= # GEOMETRY IN GRAPH # ============================================================================= # NOTE: # The following functions modify graphs in-place.
def geometry_in_graph(g):
    """Assign values to geometric entities in graphs.

    The ``"xyz"`` field of ``n1`` nodes is sent along each relation
    ``n1_as_<i>_in_<term>`` and summed into field ``"xyz<i>"`` of the
    ``<term>`` nodes. Bond lengths, angles and torsions are then computed on
    those nodes via ``apply_bond`` / ``apply_angle`` / ``apply_torsion``.

    Parameters
    ----------
    g : `dgl.DGLHeteroGraph`
        Input graph.

    Returns
    -------
    g : `dgl.DGLHeteroGraph`
        Output graph (the same object that was passed in).

    Notes
    -----
    This function modifies graphs in-place.
    """
    import dgl

    # (node term, position index) for every relation n1_as_<pos>_in_<term>.
    relations = [
        ("n%s" % big_idx, pos_idx)
        for big_idx in range(2, 5)
        for pos_idx in range(big_idx)
    ]
    relations += [
        (term, pos_idx)
        for term in ["nonbonded", "onefour"]
        for pos_idx in [0, 1]
    ]
    relations += [("n4_improper", pos_idx) for pos_idx in [0, 1, 2, 3]]

    # Copy coordinates to higher-order nodes. ``copy_u`` is used because
    # ``copy_src`` is a deprecated alias that newer DGL releases no longer
    # provide.
    g.multi_update_all(
        {
            "n1_as_%s_in_%s" % (pos_idx, term): (
                dgl.function.copy_u(u="xyz", out="m_xyz%s" % pos_idx),
                dgl.function.sum(
                    msg="m_xyz%s" % pos_idx, out="xyz%s" % pos_idx
                ),
            )
            for term, pos_idx in relations
        },
        cross_reducer="sum",
    )

    # apply geometry functions
    g.apply_nodes(apply_bond, ntype="n2")
    g.apply_nodes(apply_angle, ntype="n3")

    # These node types may have no nodes; skip them in that case. The
    # order (n4 before n4_improper) matches the previous implementation,
    # which matters because ``apply_torsion`` may draw random numbers.
    # nonbonded / onefour pairs reuse the bond-length geometry.
    for ntype, apply_fn in (
        ("n4", apply_torsion),
        ("nonbonded", apply_bond),
        ("onefour", apply_bond),
        ("n4_improper", apply_torsion),
    ):
        if g.number_of_nodes(ntype) > 0:
            g.apply_nodes(apply_fn, ntype=ntype)

    return g
class GeometryInGraph(torch.nn.Module):
    """``torch.nn.Module`` wrapper around ``geometry_in_graph``.

    Positional and keyword arguments given at construction are stored and
    forwarded to ``geometry_in_graph`` on every ``forward`` call.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        # extra arguments forwarded verbatim in ``forward``
        self.args = args
        self.kwargs = kwargs

    def forward(self, g):
        """Apply ``geometry_in_graph`` to ``g`` and return its result."""
        return geometry_in_graph(g, *self.args, **self.kwargs)