Source code for espaloma.mm.nonbonded

# =============================================================================
# IMPORTS
# =============================================================================
import torch

# =============================================================================
# CONSTANTS
# =============================================================================
import espaloma as esp
from simtk import unit

# CODATA 2018
# ref https://en.wikipedia.org/wiki/Coulomb_constant
# Coulomb constant
K_E = (
    8.9875517923 * 1e9
    * unit.newton
    * unit.meter ** 2
    * unit.coulomb ** (-2)
    * esp.units.PARTICLE ** (-1)
).value_in_unit(esp.units.COULOMB_CONSTANT_UNIT)

# =============================================================================
# UTILITY FUNCTIONS FOR COMBINATION RULES FOR NONBONDED
# =============================================================================
[docs]def geometric_mean(msg="m", out="epsilon"): def _geometric_mean(nodes): return {out: torch.prod(nodes.mailbox[msg], dim=1).pow(0.5)} return _geometric_mean
[docs]def arithmetic_mean(msg="m", out="sigma"): def _arithmetic_mean(nodes): return {out: torch.sum(nodes.mailbox[msg], dim=1).mul(0.5)} return _arithmetic_mean
# ============================================================================= # COMBINATION RULES FOR NONBONDED # =============================================================================
[docs]def lorentz_berthelot(g, suffix=""): import dgl g.multi_update_all( { "n1_as_%s_in_%s" % (pos_idx, term): ( dgl.function.copy_src( src="epsilon%s" % suffix, out="m_epsilon" ), geometric_mean(msg="m_epsilon", out="epsilon%s" % suffix), ) for pos_idx in [0, 1] for term in ["nonbonded", "onefour"] }, cross_reducer="sum", ) g.multi_update_all( { "n1_as_%s_in_%s" % (pos_idx, term): ( dgl.function.copy_src(src="sigma%s" % suffix, out="m_sigma"), arithmetic_mean(msg="m_sigma", out="sigma%s" % suffix), ) for pos_idx in [0, 1] for term in ["nonbonded", "onefour"] }, cross_reducer="sum", ) return g
[docs]def multiply_charges(g, suffix=""): """ Multiply the charges of atoms into nonbonded and onefour terms. Parameters ---------- g : dgl.HeteroGraph Input graph. Returns ------- dgl.HeteroGraph : The modified graph with charges. """ import dgl g.multi_update_all( { "n1_as_%s_in_%s" % (pos_idx, term): ( dgl.function.copy_src(src="q%s" % suffix, out="m_q"), dgl.function.sum(msg="m_q", out="_q") # lambda node: {"q%s" % suffix: node.mailbox["m_q"].prod(dim=1)} ) for pos_idx in [0, 1] for term in ["nonbonded", "onefour"] }, cross_reducer="stack", apply_node_func=lambda node: {"q": node.data["_q"].prod(dim=1)} ) return g
# ============================================================================= # ENERGY FUNCTIONS # =============================================================================
[docs]def lj_12_6(x, sigma, epsilon): """Lennard-Jones 12-6. Parameters ---------- x : `torch.Tensor`, `shape=(batch_size, 1)` or `(batch_size, batch_size, 1)` sigma : `torch.Tensor`, `shape=(batch_size, 1)` or `(batch_size, batch_size, 1)` epsilon : `torch.Tensor`, `shape=(batch_size, 1)` or `(batch_size, batch_size, 1)` Returns ------- u : `torch.Tensor`, `shape=(batch_size, 1)` or `(batch_size, batch_size, 1)` """ return esp.mm.functional.lj(x=x, sigma=sigma, epsilon=epsilon)
[docs]def lj_9_6(x, sigma, epsilon): """Lennard-Jones 9-6. Parameters ---------- x : `torch.Tensor`, `shape=(batch_size, 1)` or `(batch_size, batch_size, 1)` sigma : `torch.Tensor`, `shape=(batch_size, 1)` or `(batch_size, batch_size, 1)` epsilon : `torch.Tensor`, `shape=(batch_size, 1)` or `(batch_size, batch_size, 1)` Returns ------- u : `torch.Tensor`, `shape=(batch_size, 1)` or `(batch_size, batch_size, 1)` """ return esp.mm.functional.lj( x=x, sigma=sigma, epsilon=epsilon, order=[9, 6], coefficients=[2, 3] )
[docs]def coulomb(x, q, k_e=K_E): """ Columb interaction without cutoff. Parameters ---------- x : `torch.Tensor`, shape=`(batch_size, 1)` or `(batch_size, batch_size, 1)` Distance between atoms. q : `torch.Tensor`, `shape=(batch_size, 1) or `(batch_size, batch_size, 1)` Product of charge. Returns ------- torch.Tensor : `shape=(batch_size, 1)` or `(batch_size, batch_size, 1)` Coulomb energy. Notes ----- This computes half Coulomb energy to count for the duplication in onefour and nonbonded enumerations. """ return 0.5 * k_e * q / x