Source code for espaloma.mm.energy

# =============================================================================
# IMPORTS
# =============================================================================
import torch

import espaloma as esp


# =============================================================================
# ENERGY IN HYPERNODES---BONDED
# =============================================================================
def apply_bond(nodes, suffix=""):
    """Evaluate the harmonic bond energy for a batch of bond nodes.

    Parameters
    ----------
    nodes : object with a ``data`` mapping
        Must provide ``"x"``, ``"k<suffix>"`` and ``"eq<suffix>"``.
        (Presumably a DGL node batch, since this is used with
        ``g.apply_nodes`` -- confirm against callers.)
    suffix : str
        Appended to the parameter keys that are read and to the ``"u"``
        key that is written.

    Returns
    -------
    dict
        ``{"u<suffix>": esp.mm.bond.harmonic_bond(...)}``.
    """
    # NOTE: a suffix-dependent switch to `esp.mm.bond.harmonic_bond_re` used
    # to exist here as commented-out code; the plain harmonic form is always
    # used.
    frame = nodes.data
    energy_key = "u%s" % suffix
    energy = esp.mm.bond.harmonic_bond(
        x=frame["x"],
        k=frame["k%s" % suffix],
        eq=frame["eq%s" % suffix],
    )
    return {energy_key: energy}
def apply_angle(nodes, suffix=""):
    """Evaluate the harmonic angle energy for a batch of angle nodes.

    Parameters
    ----------
    nodes : object with a ``data`` mapping
        Must provide ``"x"``, ``"k<suffix>"`` and ``"eq<suffix>"``.
    suffix : str
        Appended to the parameter keys that are read and to the ``"u"``
        key that is written.

    Returns
    -------
    dict
        ``{"u<suffix>": esp.mm.angle.harmonic_angle(...)}``.
    """
    frame = nodes.data
    result_key = "u%s" % suffix
    harmonic_inputs = {
        "x": frame["x"],
        "k": frame["k%s" % suffix],
        "eq": frame["eq%s" % suffix],
    }
    return {result_key: esp.mm.angle.harmonic_angle(**harmonic_inputs)}
def apply_angle_ii(nodes, suffix=""):
    """Evaluate the additional (``_ii``) energy terms on angle nodes.

    All inputs are read from un-suffixed keys of ``nodes.data``; only the
    keys of the returned dictionary carry ``suffix``.

    Parameters
    ----------
    nodes : object with a ``data`` mapping
        Must provide ``"x_between"``, ``"coefficients_urey_bradley"``,
        ``"u_left"``, ``"u_right"``, ``"u"``, ``"k_bond_bond"`` and
        ``"k_bond_angle"``.
    suffix : str
        Appended to each output key.

    Returns
    -------
    dict
        Entries ``"u_urey_bradley<suffix>"``, ``"u_bond_bond<suffix>"`` and
        ``"u_bond_angle<suffix>"``.
    """
    frame = nodes.data

    # NOTE: a higher-order "u_angle_high" term (esp.mm.angle.angle_high with
    # k3 / k4) is currently disabled and therefore not returned.
    urey_bradley = esp.mm.angle.urey_bradley(
        x_between=frame["x_between"],
        coefficients=frame["coefficients_urey_bradley"],
        phases=[0.0, 12.0],
    )
    bond_bond = esp.mm.angle.bond_bond(
        u_left=frame["u_left"],
        u_right=frame["u_right"],
        k_bond_bond=frame["k_bond_bond"],
    )
    bond_angle = esp.mm.angle.bond_angle(
        u_left=frame["u_left"],
        u_right=frame["u_right"],
        u_angle=frame["u"],
        k_bond_angle=frame["k_bond_angle"],
    )

    energies = {}
    energies["u_urey_bradley%s" % suffix] = urey_bradley
    energies["u_bond_bond%s" % suffix] = bond_bond
    energies["u_bond_angle%s" % suffix] = bond_angle
    return energies
def apply_bond_ii(nodes, suffix=""):
    """Evaluate the additional (``_ii``) energy term on bond nodes.

    Parameters
    ----------
    nodes : object with a ``data`` mapping
        Must provide the un-suffixed keys ``"u"``, ``"k3"`` and ``"k4"``.
    suffix : str
        Appended to the output key only.

    Returns
    -------
    dict
        ``{"u_bond_high<suffix>": esp.mm.bond.bond_high(...)}``.
    """
    frame = nodes.data
    output_key = "u_bond_high%s" % suffix
    high_order = esp.mm.bond.bond_high(
        u_bond=frame["u"],
        k3=frame["k3"],
        k4=frame["k4"],
    )
    return {output_key: high_order}
def apply_torsion_ii(nodes, suffix=""):
    """Evaluate the additional (``_ii``) cross terms on torsion nodes.

    All inputs are read from un-suffixed keys of ``nodes.data``; only the
    keys of the returned dictionary carry ``suffix``.

    Parameters
    ----------
    nodes : object with a ``data`` mapping
        Must provide ``"u"``, ``"u_angle_left"``, ``"u_angle_right"``,
        ``"u_bond_left"``, ``"u_bond_right"``, ``"u_bond_center"`` and the
        coefficients ``"k_angle_angle"``, ``"k_angle_torsion"``,
        ``"k_angle_angle_torsion"``, ``"k_side_torsion"`` and
        ``"k_center_torsion"``.
    suffix : str
        Appended to each output key.

    Returns
    -------
    dict
        Entries ``"u_angle_angle<suffix>"``, ``"u_angle_torsion<suffix>"``,
        ``"u_angle_angle_torsion<suffix>"`` and ``"u_bond_torsion<suffix>"``.
    """
    frame = nodes.data
    torsion_terms = esp.mm.torsion

    angle_angle = torsion_terms.angle_angle(
        u_angle_left=frame["u_angle_left"],
        u_angle_right=frame["u_angle_right"],
        k_angle_angle=frame["k_angle_angle"],
    )
    angle_torsion = torsion_terms.angle_torsion(
        u_angle_left=frame["u_angle_left"],
        u_angle_right=frame["u_angle_right"],
        u_torsion=frame["u"],
        k_angle_torsion=frame["k_angle_torsion"],
    )
    angle_angle_torsion = torsion_terms.angle_angle_torsion(
        u_angle_left=frame["u_angle_left"],
        u_angle_right=frame["u_angle_right"],
        u_torsion=frame["u"],
        k_angle_angle_torsion=frame["k_angle_angle_torsion"],
    )
    bond_torsion = torsion_terms.bond_torsion(
        u_bond_left=frame["u_bond_left"],
        u_bond_right=frame["u_bond_right"],
        u_bond_center=frame["u_bond_center"],
        u_torsion=frame["u"],
        k_side_torsion=frame["k_side_torsion"],
        k_center_torsion=frame["k_center_torsion"],
    )

    energies = {}
    energies["u_angle_angle%s" % suffix] = angle_angle
    energies["u_angle_torsion%s" % suffix] = angle_torsion
    energies["u_angle_angle_torsion%s" % suffix] = angle_angle_torsion
    energies["u_bond_torsion%s" % suffix] = bond_torsion
    return energies
def apply_torsion(nodes, suffix=""):
    """Evaluate the periodic torsion energy for a batch of torsion nodes.

    When both ``"phases<suffix>"`` and ``"periodicity<suffix>"`` are present
    in ``nodes.data`` they are forwarded to
    ``esp.mm.torsion.periodic_torsion``; otherwise only ``x`` and ``k`` are
    passed and that function's own defaults apply.

    Parameters
    ----------
    nodes : object with a ``data`` mapping
        Must provide ``"x"`` and ``"k<suffix>"``.
    suffix : str
        Appended to the parameter keys that are read and to the ``"u"``
        key that is written.

    Returns
    -------
    dict
        ``{"u<suffix>": esp.mm.torsion.periodic_torsion(...)}``.
    """
    frame = nodes.data
    phases_key = "phases%s" % suffix
    periodicity_key = "periodicity%s" % suffix
    use_stored_series = phases_key in frame and periodicity_key in frame

    call_kwargs = {"x": frame["x"], "k": frame["k%s" % suffix]}
    if use_stored_series:
        call_kwargs["phases"] = frame[phases_key]
        call_kwargs["periodicity"] = frame[periodicity_key]

    return {"u%s" % suffix: esp.mm.torsion.periodic_torsion(**call_kwargs)}
def apply_improper_torsion(nodes, suffix=""):
    """Evaluate the periodic energy for a batch of improper-torsion nodes.

    Uses the same functional form as proper torsions
    (``esp.mm.torsion.periodic_torsion``). ``"phases<suffix>"`` and
    ``"periodicity<suffix>"`` are forwarded only when both are present in
    ``nodes.data``.

    Parameters
    ----------
    nodes : object with a ``data`` mapping
        Must provide ``"x"`` and ``"k<suffix>"``.
    suffix : str
        Appended to the parameter keys that are read and to the ``"u"``
        key that is written.

    Returns
    -------
    dict
        ``{"u<suffix>": esp.mm.torsion.periodic_torsion(...)}``.
    """
    frame = nodes.data
    energy_key = "u%s" % suffix
    phases_key = "phases%s" % suffix
    periodicity_key = "periodicity%s" % suffix

    # Guard clause: without a stored series, rely on the callee's defaults.
    if not (phases_key in frame and periodicity_key in frame):
        return {
            energy_key: esp.mm.torsion.periodic_torsion(
                x=frame["x"],
                k=frame["k%s" % suffix],
            )
        }

    return {
        energy_key: esp.mm.torsion.periodic_torsion(
            x=frame["x"],
            k=frame["k%s" % suffix],
            phases=frame[phases_key],
            periodicity=frame[periodicity_key],
        )
    }
def apply_bond_gaussian(nodes, suffix=""):
    """Evaluate the Gaussian-form bond energy for a batch of bond nodes.

    Parameters
    ----------
    nodes : object with a ``data`` mapping
        Must provide ``"x"`` and ``"coefficients<suffix>"``.
    suffix : str
        Appended to the coefficient key that is read and to the ``"u"``
        key that is written.

    Returns
    -------
    dict
        ``{"u<suffix>": esp.mm.bond.gaussian_bond(...)}``.
    """
    frame = nodes.data
    energy_key = "u%s" % suffix
    coefficients = frame["coefficients%s" % suffix]
    energy = esp.mm.bond.gaussian_bond(x=frame["x"], coefficients=coefficients)
    return {energy_key: energy}
def apply_bond_linear_mixture(nodes, suffix="", phases=[0.0, 1.0]):
    """Evaluate the linear-mixture bond energy for a batch of bond nodes.

    Parameters
    ----------
    nodes : object with a ``data`` mapping
        Must provide ``"x"`` and ``"coefficients<suffix>"``.
    suffix : str
        Appended to the coefficient key that is read and to the ``"u"``
        key that is written.
    phases : sequence of two numbers
        Forwarded unchanged to ``esp.mm.bond.linear_mixture_bond``.
        The default list is shared between calls; this function never
        mutates it.

    Returns
    -------
    dict
        ``{"u<suffix>": esp.mm.bond.linear_mixture_bond(...)}``.
    """
    frame = nodes.data
    energy_key = "u%s" % suffix
    mixture_inputs = {
        "x": frame["x"],
        "coefficients": frame["coefficients%s" % suffix],
        "phases": phases,
    }
    return {energy_key: esp.mm.bond.linear_mixture_bond(**mixture_inputs)}
def apply_angle_linear_mixture(nodes, suffix="", phases=[0.0, 1.0]):
    """Evaluate the linear-mixture angle energy for a batch of angle nodes.

    Parameters
    ----------
    nodes : object with a ``data`` mapping
        Must provide ``"x"`` and ``"coefficients<suffix>"``.
    suffix : str
        Appended to the coefficient key that is read and to the ``"u"``
        key that is written.
    phases : sequence of two numbers
        Forwarded unchanged to ``esp.mm.angle.linear_mixture_angle``.
        The default list is shared between calls; this function never
        mutates it.

    Returns
    -------
    dict
        ``{"u<suffix>": esp.mm.angle.linear_mixture_angle(...)}``.
    """
    frame = nodes.data
    energy_key = "u%s" % suffix
    coefficients = frame["coefficients%s" % suffix]
    energy = esp.mm.angle.linear_mixture_angle(
        x=frame["x"],
        coefficients=coefficients,
        phases=phases,
    )
    return {energy_key: energy}
# =============================================================================
# ENERGY IN HYPERNODES---NONBONDED
# =============================================================================
def apply_nonbonded(nodes, scaling=1.0, suffix=""):
    """Evaluate the scaled Lennard-Jones 12-6 energy for nonbonded pair nodes.

    Parameters
    ----------
    nodes : object with a ``data`` mapping
        Must provide ``"x"``, ``"sigma<suffix>"`` and ``"epsilon<suffix>"``.
    scaling : float
        Factor the energy is multiplied by.
    suffix : str
        Appended to the parameter keys that are read and to the ``"u"``
        key that is written.

    Returns
    -------
    dict
        ``{"u<suffix>": scaling * esp.mm.nonbonded.lj_12_6(...)}``.
    """
    # TODO: should this be 9-6 or 12-6?
    frame = nodes.data
    energy_key = "u%s" % suffix
    unscaled = esp.mm.nonbonded.lj_12_6(
        x=frame["x"],
        sigma=frame["sigma%s" % suffix],
        epsilon=frame["epsilon%s" % suffix],
    )
    return {energy_key: scaling * unscaled}
def apply_coulomb(nodes, scaling=1.0, suffix=""):
    """Evaluate the scaled Coulomb energy for nonbonded pair nodes.

    Parameters
    ----------
    nodes : object with a ``data`` mapping
        Must provide ``"x"`` and ``"q"``. Both are read un-suffixed.
    scaling : float
        Factor the energy is multiplied by.
    suffix : str
        Appended to the ``"u"`` key that is written (not to the inputs).

    Returns
    -------
    dict
        ``{"u<suffix>": scaling * esp.mm.nonbonded.coulomb(...)}``.
    """
    frame = nodes.data
    energy_key = "u%s" % suffix
    unscaled = esp.mm.nonbonded.coulomb(x=frame["x"], q=frame["q"])
    return {energy_key: scaling * unscaled}
# =============================================================================
# ENERGY IN GRAPH
# =============================================================================
def energy_in_graph(
    g,
    suffix="",
    terms=["n2", "n3", "n4"],
):  # "onefour", "nonbonded"]):
    """Calculate the energy of a small molecule given parameters and geometry.

    For every requested term the per-node energy ``"u<suffix>"`` is computed
    on the corresponding node type, summed into ``"u_<term><suffix>"`` on the
    ``"g"`` node through the ``"<term>_in_g"`` edges, and the per-term sums
    are added up into ``"u<suffix>"`` on ``"g"``.

    Parameters
    ----------
    g : `dgl.DGLHeteroGraph`
        Input graph.
    suffix : str
        Appended to the parameter keys that are read and to the energy keys
        that are written.
    terms : sequence of str
        Node types whose energies are evaluated and summed. Recognised
        values are ``"n2"``, ``"n3"``, ``"n4"``, ``"n4_improper"``,
        ``"nonbonded"`` and ``"onefour"``. The default list is never mutated
        here.

    Returns
    -------
    g : `dgl.DGLHeteroGraph`
        Output graph.

    Notes
    -----
    This function modifies graphs in-place.

    """
    # TODO: this is all very restricted for now
    # we need to make this better
    import math

    import dgl

    coefficients_key = "coefficients%s" % suffix

    if "n2" in terms:
        # linear-mixture form when coefficients are present, harmonic otherwise
        if coefficients_key in g.nodes["n2"].data:
            g.apply_nodes(
                lambda node: apply_bond_linear_mixture(
                    node, suffix=suffix, phases=[1.5, 6.0]
                ),
                ntype="n2",
            )
        else:
            g.apply_nodes(
                lambda node: apply_bond(node, suffix=suffix),
                ntype="n2",
            )

    if "n3" in terms:
        if coefficients_key in g.nodes["n3"].data:
            g.apply_nodes(
                lambda node: apply_angle_linear_mixture(
                    node, suffix=suffix, phases=[0.0, math.pi]
                ),
                ntype="n3",
            )
        else:
            g.apply_nodes(
                lambda node: apply_angle(node, suffix=suffix),
                ntype="n3",
            )

    # The membership test comes first so that `number_of_nodes` is never
    # queried for a node type that was not requested (and that the graph may
    # not define at all); this matches the nonbonded / onefour guards below.
    if "n4" in terms and g.number_of_nodes("n4") > 0:
        g.apply_nodes(
            lambda node: apply_torsion(node, suffix=suffix),
            ntype="n4",
        )

    if "n4_improper" in terms and g.number_of_nodes("n4_improper") > 0:
        g.apply_nodes(
            lambda node: apply_improper_torsion(node, suffix=suffix),
            ntype="n4_improper",
        )

    # NOTE: Lennard-Jones contributions (`apply_nonbonded`, with 0.5 scaling
    # for "onefour") are currently disabled; only Coulomb terms are evaluated
    # for the "nonbonded" and "onefour" node types.
    if "nonbonded" in terms or "onefour" in terms:
        esp.mm.nonbonded.multiply_charges(g)

    if "nonbonded" in terms and g.number_of_nodes("nonbonded") > 0:
        g.apply_nodes(
            lambda node: apply_coulomb(
                node,
                suffix=suffix,
                scaling=1.0,
            ),
            ntype="nonbonded",
        )

    if "onefour" in terms and g.number_of_nodes("onefour") > 0:
        g.apply_nodes(
            lambda node: apply_coulomb(
                node,
                suffix=suffix,
                # scaling=0.5,
                scaling=0.8333333333333334,
            ),
            ntype="onefour",
        )

    # sum up energy
    # `copy_u` replaces the deprecated `copy_src` alias, which newer DGL
    # releases no longer provide.
    energy_key = "u%s" % suffix
    g.multi_update_all(
        {
            "%s_in_g"
            % term: (
                dgl.function.copy_u(energy_key, "m_%s" % term),
                dgl.function.sum(
                    msg="m_%s" % term, out="u_%s%s" % (term, suffix)
                ),
            )
            for term in terms
            if energy_key in g.nodes[term].data
        },
        cross_reducer="sum",
    )

    g.apply_nodes(
        lambda node: {
            "u%s"
            % suffix: sum(
                node.data["u_%s%s" % (term, suffix)]
                for term in terms
                if "u_%s%s" % (term, suffix) in node.data
            )
        },
        ntype="g",
    )

    # NOTE(review): the offset is added to the un-suffixed "u" regardless of
    # `suffix` -- confirm this is intended.
    if "u0" in g.nodes["g"].data:
        g.apply_nodes(
            lambda node: {"u": node.data["u"] + node.data["u0"]},
            ntype="g",
        )

    return g
def energy_in_graph_ii(
    g,
    suffix="",
):
    """Add the ``_ii`` energy terms to the angle and torsion node energies.

    For node types ``"n3"`` and ``"n4"`` (when the graph has at least one
    such node) the extra terms are evaluated with ``apply_angle_ii`` /
    ``apply_torsion_ii`` and then added onto the existing ``"u<suffix>"``.

    Parameters
    ----------
    g : `dgl.DGLHeteroGraph`
        Input graph; ``"u<suffix>"`` must already be present on the node
        types that are updated.
    suffix : str
        Appended to the energy keys that are read and written.

    Returns
    -------
    g : `dgl.DGLHeteroGraph`
        The same graph, modified in-place.
    """
    energy_key = "u%s" % suffix

    def _fold_into_energy(templates):
        # Build a node UDF that adds the named terms, in order, onto the
        # node's existing energy.
        def _update(node):
            total = node.data[energy_key]
            for template in templates:
                total = total + node.data[template % suffix]
            return {energy_key: total}

        return _update

    def _angle_terms(node):
        return apply_angle_ii(node, suffix=suffix)

    def _torsion_terms(node):
        return apply_torsion_ii(node, suffix=suffix)

    if g.number_of_nodes("n3") > 0:
        g.apply_nodes(_angle_terms, ntype="n3")
        g.apply_nodes(
            _fold_into_energy(
                ("u_urey_bradley%s", "u_bond_bond%s", "u_bond_angle%s")
            ),
            ntype="n3",
        )

    if g.number_of_nodes("n4") > 0:
        g.apply_nodes(_torsion_terms, ntype="n4")
        g.apply_nodes(
            _fold_into_energy(
                (
                    "u_angle_angle%s",
                    "u_angle_torsion%s",
                    "u_angle_angle_torsion%s",
                    "u_bond_torsion%s",
                )
            ),
            ntype="n4",
        )

    return g
class EnergyInGraph(torch.nn.Module):
    """Module wrapper around ``energy_in_graph``.

    Positional and keyword arguments given at construction time are stored
    and forwarded, after the graph, on every call to ``forward``.
    """

    def __init__(self, *args, **kwargs):
        """Store the extra arguments to pass to ``energy_in_graph``."""
        super().__init__()
        self.kwargs = kwargs
        self.args = args

    def forward(self, g):
        """Run ``energy_in_graph`` on ``g`` with the stored arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return energy_in_graph(g, *extra_args, **extra_kwargs)
class EnergyInGraphII(torch.nn.Module):
    """Module wrapper around ``energy_in_graph_ii``.

    Positional and keyword arguments given at construction time are stored
    and forwarded, after the graph, on every call to ``forward``.
    """

    def __init__(self, *args, **kwargs):
        """Store the extra arguments to pass to ``energy_in_graph_ii``."""
        super().__init__()
        self.kwargs = kwargs
        self.args = args

    def forward(self, g):
        """Run ``energy_in_graph_ii`` on ``g`` with the stored arguments."""
        extra_args, extra_kwargs = self.args, self.kwargs
        return energy_in_graph_ii(g, *extra_args, **extra_kwargs)
class CarryII(torch.nn.Module):
    """Copy lower-order energies onto the higher-order nodes that use them.

    The ``"u"`` value of each bond (``n2``) node is summed onto the angle
    (``n3``) and torsion (``n4``) nodes it belongs to, and the ``"u"`` value
    of each angle node onto its torsion nodes, under position-specific keys
    (``"u_left"``, ``"u_right"``, ``"u_bond_left"``, ``"u_bond_center"``,
    ``"u_bond_right"``, ``"u_angle_left"``, ``"u_angle_right"``).
    """

    def forward(self, g):
        """Propagate ``"u"`` along the ``*_as_*_in_*`` edges of ``g``.

        Parameters
        ----------
        g : `dgl.DGLHeteroGraph`
            Input graph with ``"u"`` on its ``n2`` and ``n3`` nodes.

        Returns
        -------
        g : `dgl.DGLHeteroGraph`
            The same graph, modified in-place.
        """
        import dgl

        # `copy_u` replaces the deprecated `copy_src` alias, which newer DGL
        # releases no longer provide.
        copy_u = dgl.function.copy_u
        reduce_sum = dgl.function.sum

        g.multi_update_all(
            {
                "n2_as_0_in_n3": (
                    copy_u("u", "m_u_0"),
                    reduce_sum("m_u_0", "u_left"),
                ),
                "n2_as_1_in_n3": (
                    copy_u("u", "m_u_1"),
                    reduce_sum("m_u_1", "u_right"),
                ),
                "n2_as_0_in_n4": (
                    copy_u("u", "m_u_0"),
                    reduce_sum("m_u_0", "u_bond_left"),
                ),
                "n2_as_1_in_n4": (
                    copy_u("u", "m_u_1"),
                    reduce_sum("m_u_1", "u_bond_center"),
                ),
                "n2_as_2_in_n4": (
                    copy_u("u", "m_u_2"),
                    reduce_sum("m_u_2", "u_bond_right"),
                ),
                "n3_as_0_in_n4": (
                    copy_u("u", "m3_u_0"),
                    reduce_sum("m3_u_0", "u_angle_left"),
                ),
                "n3_as_1_in_n4": (
                    copy_u("u", "m3_u_1"),
                    reduce_sum("m3_u_1", "u_angle_right"),
                ),
            },
            cross_reducer="sum",
        )

        return g