Source code for espaloma.nn.readout.charge_equilibrium

""" Charge equilibrium.

"""
# =============================================================================
# IMPORTS
# =============================================================================
import torch

# =============================================================================
# UTILITY FUNCTIONS
# =============================================================================
def get_charges(node):
    r"""Solve analytically for the charges that minimize the equilibration energy.

    This is a node user-defined function: it is handed a batch of nodes and
    returns the new node features computed from the entries already stored
    in ``node.data``.

    Parameters
    ----------
    node : object exposing a ``data`` mapping
        Batch of nodes. The following ``torch.Tensor`` entries are read
        from ``node.data``:

        * ``"e"`` -- electronegativity of each node.
        * ``"s"`` -- hardness of each node.
        * ``"sum_e_s_inv"`` -- :math:`\sum_i e_i s_i^{-1}` over the
          molecule the node belongs to.
        * ``"sum_s_inv"`` -- :math:`\sum_j s_j^{-1}` over the molecule
          the node belongs to.
        * ``"sum_q"`` -- total charge :math:`Q` of the molecule the node
          belongs to.

    Returns
    -------
    dict
        ``{"q": q}`` where ``q`` holds the equilibrated charge of each
        node.

    Notes
    -----
    We use Lagrange multipliers to analytically give the solution.

    $$
    U(\mathbf{q})
    &= \sum_{i=1}^N \left[ e_i q_i + \frac{1}{2} s_i q_i^2 \right]
       - \lambda \, \left( \sum_{j=1}^N q_j - Q \right) \\
    &= \sum_{i=1}^N \left[ (e_i - \lambda) q_i
       + \frac{1}{2} s_i q_i^2 \right] + \lambda Q
    $$

    This gives us:

    $$
    q_i^*
    &= - e_i s_i^{-1} + \lambda s_i^{-1} \\
    &= - e_i s_i^{-1} + s_i^{-1}
       \frac{Q + \sum\limits_{i=1}^N e_i \, s_i^{-1}}
            {\sum\limits_{j=1}^N s_j^{-1}}
    $$

    No guard is applied against zero hardness; ``s == 0`` propagates
    ``inf``/``nan`` into the result.
    """
    data = node.data
    e = data["e"]
    # Inverse hardness is needed twice below; compute it once.
    s_inv = data["s"] ** -1

    # Lagrange multiplier lambda enforcing sum_i q_i == Q.
    multiplier = torch.div(
        data["sum_q"] + data["sum_e_s_inv"],
        data["sum_s_inv"],
    )

    return {"q": -e * s_inv + s_inv * multiplier}
# =============================================================================
# MODULE CLASS
# =============================================================================
class ChargeEquilibrium(torch.nn.Module):
    """Charge equilibrium within batches of molecules.

    The module holds no parameters of its own. ``forward`` reads the
    electronegativity ``e`` and hardness ``s`` stored on the ``n1`` nodes
    of a heterogeneous graph and writes the equilibrated charges ``q``
    back onto those nodes.
    """

    def __init__(self):
        super().__init__()

    def forward(self, g, total_charge=0.0):
        """Apply charge equilibrium to all molecules in the batch.

        Parameters
        ----------
        g : heterogeneous DGL graph
            Must provide node types ``"n1"`` and ``"g"``, edge types
            ``"n1_in_g"`` and ``"g_has_n1"``, and a ``batch_size``
            attribute. The ``"n1"`` nodes must carry ``"e"`` and ``"s"``;
            if they also carry ``"q_ref"``, its per-molecule sum is used
            as the total charge.
        total_charge : float, default 0.0
            Total charge assigned to every molecule in the batch. Only
            used when the ``"n1"`` nodes have no ``"q_ref"`` entry.

        Returns
        -------
        The same graph ``g``, modified in place. The ``"n1"`` nodes gain
        ``"q"`` (the result) as well as the intermediates ``"s_inv"``,
        ``"e_s_inv"``, ``"sum_q"``, ``"sum_s_inv"`` and ``"sum_e_s_inv"``;
        the ``"g"`` nodes gain the three ``"sum_*"`` entries.
        """
        # Kept as a function-local import, as in the original code.
        import dgl

        fn = dgl.function

        # calculate $s ^ {-1}$ and $ e s ^ {-1}$ on every n1 node
        g.apply_nodes(
            lambda node: {"s_inv": node.data["s"] ** -1},
            ntype="n1",
        )
        g.apply_nodes(
            lambda node: {"e_s_inv": node.data["e"] * node.data["s"] ** -1},
            ntype="n1",
        )

        # get the total charge of each molecule onto its "g" node
        if "q_ref" in g.nodes["n1"].data:
            # sum the reference charges of the n1 nodes of each molecule
            g.update_all(
                fn.copy_u("q_ref", "m_q"),
                fn.sum("m_q", "sum_q"),
                etype="n1_in_g",
            )
        else:
            # no reference charges: use the same total for every molecule
            g.nodes["g"].data["sum_q"] = (
                torch.ones(
                    g.batch_size,
                    1,
                    device=g.nodes["n1"].data["s"].device,
                )
                * total_charge
            )

        # hand the total charge back down to the n1 nodes
        # NOTE(review): the sum reducer acts as a plain copy only if each
        # n1 node has exactly one incoming "g_has_n1" edge -- confirm.
        g.update_all(
            fn.copy_u("sum_q", "m_sum_q"),
            fn.sum("m_sum_q", "sum_q"),
            etype="g_has_n1",
        )

        # get the per-molecule sums of $s^{-1}$ and $e s^{-1}$
        for field in ("s_inv", "e_s_inv"):
            g.update_all(
                fn.copy_u(field, "m_" + field),
                fn.sum("m_" + field, "sum_" + field),
                etype="n1_in_g",
            )

        # and send those sums back down to the n1 nodes
        for field in ("sum_s_inv", "sum_e_s_inv"):
            g.update_all(
                fn.copy_u(field, "m_" + field),
                fn.sum("m_" + field, field),
                etype="g_has_n1",
            )

        # every n1 node now holds all inputs needed for the closed form
        g.apply_nodes(get_charges, ntype="n1")

        return g