# =============================================================================
# IMPORTS
# =============================================================================
import torch
# =============================================================================
# UTILITY FUNCTIONS
# =============================================================================
def reduce_stack(msg, out):
    """Build a reduce function that forwards a mailbox field unchanged.

    Parameters
    ----------
    msg : str
        Name of the mailbox entry to read.
    out : str
        Name of the node field to write.

    Returns
    -------
    callable
        A function ``(nodes) -> {out: nodes.mailbox[msg]}``. ``msg`` and
        ``out`` may also be overridden per call as keyword arguments.
    """
    def _reduce_stack(nodes, msg=msg, out=out):
        mailbox = nodes.mailbox
        return {out: mailbox[msg]}

    return _reduce_stack
def copy_src(src, out):
    """Build a message function that copies a source-node field onto edges.

    Parameters
    ----------
    src : str
        Name of the field read from ``edges.src``.
    out : str
        Name of the message field to write.

    Returns
    -------
    callable
        A function ``(edges) -> {out: edges.src[src].clone()}``. The value is
        cloned, so the message does not share storage with the source field.
    """
    def _copy_src(edges, src=src, out=out):
        copied = edges.src[src].clone()
        return {out: copied}

    return _copy_src
# =============================================================================
# SINGLE GEOMETRY ENTITY
# =============================================================================
def distance(x0, x1):
    """Euclidean distance between points ``x0`` and ``x1``.

    The 2-norm of ``x0 - x1`` is taken over the last dimension, which is
    therefore removed from the result.
    """
    displacement = x0 - x1
    return displacement.norm(p=2, dim=-1)
def _angle(r0, r1):
    """Angle between vectors.

    Computed as ``atan2(|r0 x r1|, r0 . r1)``.

    Parameters
    ----------
    r0, r1 : torch.Tensor
        Vectors stored along the last dimension, which must have size 3.
        Any leading dimensions are treated as batch dimensions.

    Returns
    -------
    torch.Tensor
        Angles in radians in ``[0, pi]`` (the first atan2 argument is a
        norm, hence non-negative), with the last dimension reduced.
    """
    # dim=-1 must be explicit: without it torch.cross uses the *first*
    # dimension of size 3, which is the batch axis whenever exactly three
    # vectors are passed, silently producing wrong angles.
    cross = torch.cross(r0, r1, dim=-1)
    angle = torch.atan2(
        torch.norm(cross, p=2, dim=-1),
        torch.sum(torch.mul(r0, r1), dim=-1),
    )
    return angle
def angle(x0, x1, x2):
    """Angle at vertex ``x1`` of the three points ``x0``, ``x1``, ``x2``.

    Returns the angle (in radians) between the displacements ``x1 - x0``
    and ``x1 - x2``, as computed by ``_angle``.
    """
    return _angle(x1 - x0, x1 - x2)
def _dihedral(r0, r1):
    """Dihedral between two plane normals.

    Delegates to ``_angle``, i.e. returns the unsigned angle between the
    normal vectors ``r0`` and ``r1``.
    """
    angle_between_normals = _angle(r0, r1)
    return angle_between_normals
def dihedral(
    x0: torch.Tensor, x1: torch.Tensor, x2: torch.Tensor, x3: torch.Tensor
) -> torch.Tensor:
    """Dihedral between four points.

    Parameters
    ----------
    x0, x1, x2, x3 : torch.Tensor
        Point coordinates, all with the same shape. The last dimension holds
        the Cartesian components and must have size 3; leading dimensions
        are treated as batch dimensions.

    Returns
    -------
    torch.Tensor
        Signed dihedral angle in radians, in ``[-pi, pi]``, with the last
        dimension reduced.

    Notes
    -----
    Every displacement is perturbed with Gaussian noise of scale ``1e-5``,
    so repeated calls on the same input return slightly different values.

    Reference
    ---------
    Closely follows implementation in Yutong Zhao's timemachine:
    https://github.com/proteneer/timemachine/blob/1a0ab45e605dc1e28c44ea90f38cb0dedce5c4db/timemachine/potentials/bonded.py#L152-L199
    """
    # check input shapes
    assert x0.shape == x1.shape == x2.shape == x3.shape

    # compute displacements 0->1, 2->1, 2->3
    # NOTE(review): the tiny random jitter presumably keeps degenerate
    # (e.g. collinear / coincident) geometries away from zero-length vectors
    # and NaNs — TODO confirm intent.
    r01 = x1 - x0 + torch.randn_like(x0) * 1e-5
    r21 = x1 - x2 + torch.randn_like(x0) * 1e-5
    r23 = x3 - x2 + torch.randn_like(x0) * 1e-5

    # compute normal planes.
    # dim=-1 must be explicit: without it torch.cross uses the first
    # dimension of size 3, which is the batch axis when exactly three
    # dihedrals are evaluated at once.
    n1 = torch.cross(r01, r21, dim=-1)
    n2 = torch.cross(r21, r23, dim=-1)

    rkj_normed = r21 / torch.norm(r21, dim=-1, keepdim=True)
    y = torch.sum(torch.mul(torch.cross(n1, n2, dim=-1), rkj_normed), dim=-1)
    x = torch.sum(torch.mul(n1, n2), dim=-1)

    # choose quadrant correctly (atan2 keeps the sign of the dihedral)
    theta = torch.atan2(y, x)
    return theta
# =============================================================================
# GEOMETRY IN HYPERNODES
# =============================================================================
def apply_bond(nodes):
    """Bond length for each node, from its ``xyz0`` and ``xyz1`` coordinates.

    Returns a ``{"x": length}`` dict, as consumed by ``apply_nodes``.
    """
    data = nodes.data
    bond_length = distance(x0=data["xyz0"], x1=data["xyz1"])
    return {"x": bond_length}
def apply_angle(nodes):
    """Angle values in nodes, plus the three pairwise distances.

    Returns
    -------
    dict
        ``"x"``: angle at ``xyz1`` formed with ``xyz0`` and ``xyz2``;
        ``"x_left"``: distance between ``xyz1`` and ``xyz0``;
        ``"x_right"``: distance between ``xyz1`` and ``xyz2``;
        ``"x_between"``: distance between ``xyz0`` and ``xyz2``.
    """
    data = nodes.data
    p0, p1, p2 = data["xyz0"], data["xyz1"], data["xyz2"]

    return {
        "x": angle(x0=p0, x1=p1, x2=p2),
        "x_left": distance(x0=p1, x1=p0),
        "x_right": distance(x0=p1, x1=p2),
        "x_between": distance(x0=p0, x1=p2),
    }
def apply_torsion(nodes):
    """Torsion dihedrals in nodes, plus the bonds and angles they span.

    Returns
    -------
    dict
        ``"x"``: dihedral over ``xyz0``..``xyz3``;
        ``"x_bond_left"`` / ``"x_bond_center"`` / ``"x_bond_right"``:
        distances 0-1, 1-2 and 2-3;
        ``"x_angle_left"`` / ``"x_angle_right"``: angles at ``xyz1``
        (points 0-1-2) and at ``xyz2`` (points 1-2-3).
    """
    data = nodes.data
    p0, p1, p2, p3 = data["xyz0"], data["xyz1"], data["xyz2"], data["xyz3"]

    return {
        "x": dihedral(x0=p0, x1=p1, x2=p2, x3=p3),
        "x_bond_left": distance(x0=p0, x1=p1),
        "x_bond_center": distance(x0=p1, x1=p2),
        "x_bond_right": distance(x0=p2, x1=p3),
        "x_angle_left": angle(x0=p0, x1=p1, x2=p2),
        "x_angle_right": angle(x0=p1, x1=p2, x2=p3),
    }
# =============================================================================
# GEOMETRY IN GRAPH
# =============================================================================
# NOTE:
# The following functions modify graphs in-place.
def geometry_in_graph(g):
    """Assign values to geometric entities in graphs.

    Coordinates are read from the ``xyz`` field of the source nodes of
    every ``n1_as_<pos>_in_<type>`` edge type and written to field
    ``xyz<pos>`` on the destination nodes. Bond lengths, angles and
    dihedrals are then computed from those fields on the hypernode types.

    Parameters
    ----------
    g : `dgl.DGLHeteroGraph`
        Input graph.

    Returns
    -------
    g : `dgl.DGLHeteroGraph`
        Output graph (the same object that was passed in).

    Notes
    -----
    This function modifies graphs in-place.
    """
    import dgl

    def _copy_xyz(pos_idx):
        # (message, reduce) pair moving ``xyz`` into ``xyz<pos_idx>``.
        # ``copy_u`` replaces ``copy_src``, which DGL deprecated and later
        # removed. NOTE(review): presumably each hypernode has exactly one
        # incoming edge per position, so ``sum`` acts as a copy — confirm.
        msg = "m_xyz%s" % pos_idx
        return (
            dgl.function.copy_u(u="xyz", out=msg),
            dgl.function.sum(msg=msg, out="xyz%s" % pos_idx),
        )

    # Edge-type name -> (message, reduce), in the same order as before:
    # n2/n3/n4 hypernodes, then nonbonded, onefour and n4_improper.
    funcs = {}
    for big_idx in range(2, 5):
        for pos_idx in range(big_idx):
            funcs["n1_as_%s_in_n%s" % (pos_idx, big_idx)] = _copy_xyz(pos_idx)
    for term, n_positions in (
        ("nonbonded", 2),
        ("onefour", 2),
        ("n4_improper", 4),
    ):
        for pos_idx in range(n_positions):
            funcs["n1_as_%s_in_%s" % (pos_idx, term)] = _copy_xyz(pos_idx)

    # Copy coordinates to higher-order nodes.
    g.multi_update_all(funcs, cross_reducer="sum")

    # apply geometry functions
    g.apply_nodes(apply_bond, ntype="n2")
    g.apply_nodes(apply_angle, ntype="n3")
    if g.number_of_nodes("n4") > 0:
        g.apply_nodes(apply_torsion, ntype="n4")

    # Distances for nonbonded / 1-4 pairs and dihedrals for impropers;
    # each node type is skipped when the graph has none of it.
    if g.number_of_nodes("nonbonded") > 0:
        g.apply_nodes(apply_bond, ntype="nonbonded")
    if g.number_of_nodes("onefour") > 0:
        g.apply_nodes(apply_bond, ntype="onefour")
    if g.number_of_nodes("n4_improper") > 0:
        g.apply_nodes(apply_torsion, ntype="n4_improper")
    return g
class GeometryInGraph(torch.nn.Module):
    """``torch.nn.Module`` wrapper around ``geometry_in_graph``.

    Positional and keyword arguments given at construction are stored on
    the instance and forwarded to ``geometry_in_graph`` on every ``forward``.
    NOTE(review): ``geometry_in_graph`` only accepts ``g``, so any stored
    extra arguments would make ``forward`` raise ``TypeError``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.args = args
        self.kwargs = kwargs

    def forward(self, g):
        """Run ``geometry_in_graph`` on ``g`` (modified in place) and return it."""
        return geometry_in_graph(g, *self.args, **self.kwargs)