Source code for espaloma.mm.functional

# =============================================================================
# IMPORTS
# =============================================================================
import math
import torch
import espaloma as esp

# =============================================================================
# CONSTANTS
# =============================================================================
from simtk import unit
from simtk.unit import Quantity

# Default value of the `switch` argument of `lj`: a length of 1 angstrom
# converted to a plain number expressed in `esp.units.DISTANCE_UNIT`.
LJ_SWITCH = Quantity(1.0, unit.angstrom).value_in_unit(
    esp.units.DISTANCE_UNIT
)

# =============================================================================
# UTILITY FUNCTIONS
# =============================================================================
def linear_mixture_to_original(k1, k2, b1, b2):
    """Translate linear-mixture coefficients back to the original parameters.

    Parameters
    ----------
    k1, k2 : mixture coefficients (numbers or tensors supporting `+`, `*`, `/`)
    b1, b2 : the two values the coefficients are attached to

    Returns
    -------
    k : `k1 + k2`
    b : average of `b1` and `b2` weighted by `k1` and `k2`; the denominator
        is offset by 1e-7 (presumably to avoid dividing by an exactly-zero
        total -- confirm with callers).
    """
    total = k1 + k2
    weighted = k1 * b1 + k2 * b2
    center = weighted / (total + 1e-7)
    return total, center
# =============================================================================
# MODULE FUNCTIONS
# =============================================================================
def harmonic(x, k, eq, order=[2]):
    """Harmonic term.

    Evaluates ``0.5 * k * sum_n (x - eq) ** n`` where ``n`` runs over the
    exponents given in ``order``.

    Parameters
    ----------
    x : `torch.Tensor`
        Coordinate values. ``x - eq`` must be two-dimensional, because the
        result of raising it to every exponent is permuted as a 3-D tensor.
    k : `torch.Tensor`
        Force constants, broadcastable against ``x - eq``.
    eq : `torch.Tensor`
        Equilibrium values, broadcastable against ``x``.
    order : `int`, `list`/`tuple` of `int`, or `torch.Tensor`
        Exponent(s) to sum over. A single integer or a 0-dim tensor is
        treated as a one-element sequence.

    Returns
    -------
    u : `torch.Tensor`, same shape as ``k * (x - eq)``
    """
    # NOTE: the list default is kept for interface compatibility; it is only
    # read (converted to a tensor), never mutated.
    if isinstance(order, int):
        # The documented scalar form: previously crashed on `order[:, ...]`.
        order = [order]
    if isinstance(order, (list, tuple)):
        order = torch.tensor(order, device=x.device)
    elif order.dim() == 0:
        order = order.reshape(1)

    displacement = x - eq
    # (n_orders, rows, cols): one slab of powers per requested exponent.
    powers = displacement.pow(order[:, None, None])
    # Move the exponent axis last and sum it away.
    return 0.5 * k * powers.permute(1, 2, 0).sum(dim=-1)
def periodic_fixed_phases(
    dihedrals: torch.Tensor, ks: torch.Tensor
) -> torch.Tensor:
    """Periodic torsion term with n_phases = 6, periodicities = 1..n_phases,
    phases = zeros.

    For every snapshot, sums ``k_n * cos(n * theta)`` over all dihedrals and
    over ``n = 1..6``.

    Parameters
    ----------
    dihedrals : torch.Tensor, shape=(n_snapshots, n_dihedrals)
        dihedral angles; they are passed to `torch.cos` unchanged, i.e.
        interpreted as radians
    ks : torch.Tensor, shape=(n_dihedrals, n_phases)
        force constants -- TODO: confirm in esp.unit.ENERGY_UNIT ?

    Returns
    -------
    u : torch.Tensor, shape=(n_snapshots, 1)
        potential energy of each snapshot

    Raises
    ------
    AssertionError
        if `ks` does not have shape (n_dihedrals, 6)

    Notes
    -----
    TODO: merge with esp.mm.functional.periodic -- this variant exists
    because it accepts exactly one combination of input shapes, which makes
    runtime shape errors easier to debug.
    """
    n_phases = 6

    # assert input shape consistency
    n_snapshots, n_dihedrals = dihedrals.shape
    n_dihedrals_, n_phases_ = ks.shape
    assert n_dihedrals == n_dihedrals_
    assert n_phases == n_phases_

    # periodicity = 1..n_phases, created next to the angles so that
    # non-CPU inputs do not hit a device mismatch
    periodicity = torch.arange(1, n_phases + 1, device=dihedrals.device)

    # broadcast to (n_snapshots, n_dihedrals, n_phases) instead of
    # materialising stacked copies of every operand
    n_theta = periodicity[None, None, :] * dihedrals[:, :, None]
    energy_terms = ks[None, :, :] * torch.cos(n_theta)
    assert energy_terms.shape == (n_snapshots, n_dihedrals, n_phases)

    # sum over n_dihedrals and n_phases
    energy_sums = energy_terms.sum(dim=(1, 2))
    assert energy_sums.shape == (n_snapshots,)

    return energy_sums.reshape((n_snapshots, 1))
def periodic(
    x, k, periodicity=list(range(1, 7)), phases=[0.0 for _ in range(6)]
):
    """Periodic term.

    Each phase contributes ``|k| * (1 + cos(n * x - phase))`` when its ``k``
    is positive and ``|k| * (1 - cos(n * x - phase))`` when it is negative;
    contributions are summed over the last (phase) axis.

    Parameters
    ----------
    x : `torch.Tensor`, two-dimensional
        (indexed as ``x[:, :, None]``; the original documentation gives
        `shape=(batch_size, 1)`)
    k : `torch.Tensor`, `shape=(batch_size, number_of_phases)`
    periodicity : either list of length number_of_phases, or `torch.Tensor`
        with one axis (number_of_phases) or two axes
        (batch_size, number_of_phases); tensors of higher rank are used as
        given
    phases : same accepted forms as `periodicity`

    Returns
    -------
    energy : `torch.Tensor`, one value per entry of `x`
    """
    n_rows, n_cols = x.shape[0], x.shape[1]

    def _spread(param):
        # Tile a 1-D or 2-D parameter up to (n_rows, n_cols, n_phases);
        # anything else is passed through untouched.
        if param.ndim == 1:
            return param[None, None, :].repeat(n_rows, n_cols, 1)
        if param.ndim == 2:
            return param[:, None, :].repeat(1, n_cols, 1)
        return param

    if isinstance(periodicity, list):
        periodicity = torch.tensor(
            periodicity,
            device=x.device,
            dtype=torch.get_default_dtype(),
        )
    if isinstance(phases, list):
        phases = torch.tensor(phases, device=x.device)

    periodicity = _spread(periodicity)
    phases = _spread(phases)

    angle = periodicity * x[:, :, None] - phases
    cosine = angle.cos()

    k = k[:, None, :].repeat(1, n_cols, 1)

    # Split k by sign so that every phase contributes a non-negative energy.
    k_positive = torch.nn.functional.relu(k)
    k_negative = torch.nn.functional.relu(0.0 - k)
    per_phase = k_positive * (cosine + 1.0) - k_negative * (cosine - 1.0)
    return per_phase.sum(dim=-1)
# simple implementation
# def harmonic(x, k, eq):
#     return k * (x - eq) ** 2
#
# def harmonic_re(x, k, eq, a=0.0, b=0.3):
#     # temporary
#     ka = k
#     kb = eq
#
#     c = ((ka * a + kb * b) / (ka + kb)) ** 2 - a ** 2 - b ** 2
#
#     return ka * (x - a) ** 2 + kb * (x - b) ** 2
def lj(
    x,
    epsilon,
    sigma,
    order=[12, 6],
    coefficients=[1.0, 1.0],
    switch=LJ_SWITCH,
):
    r"""Lennard-Jones term.

    Notes
    -----
    .. math::
        E = \epsilon (c_0 (\sigma / r) ^ {12} - c_1 (\sigma / r) ^ 6)

    with the exponents taken from ``order`` and :math:`c_0, c_1` from
    ``coefficients`` (both 1 by default). Wherever ``x < switch`` the ratio
    :math:`\sigma / r` is replaced by zero, so those entries contribute no
    energy.

    Parameters
    ----------
    x : `torch.Tensor`
        distances
    epsilon : `torch.Tensor`, broadcastable against `x`
    sigma : `torch.Tensor`, broadcastable against `x`
    order : `list` of two `int`, or 1-D `torch.Tensor` of length 2
        repulsive and attractive exponents, in that order
    coefficients : torch.tensor or list
        multipliers of the repulsive and attractive terms, in that order
    switch : unitless switch width (distance)

    Returns
    -------
    u : `torch.Tensor`, broadcast shape of `epsilon`, `sigma` and `x`

    Raises
    ------
    AssertionError
        if `order` is not one-dimensional with exactly two entries
    """
    if isinstance(order, list):
        order = torch.tensor(order, device=x.device)

    if isinstance(coefficients, list):
        coefficients = torch.tensor(coefficients, device=x.device)

    assert order.shape[0] == 2
    assert order.dim() == 1

    # TODO: the switch is for experiments only; erase later.
    below_switch = torch.lt(x, switch)

    # Mask the denominator *before* dividing. Dividing first and zeroing the
    # result afterwards gives the same forward values, but a zero distance
    # then produces inf in the forward pass and 0 * inf = NaN in the backward
    # pass even though that entry is discarded.
    safe_x = torch.where(below_switch, torch.ones_like(x), x)
    sigma_over_x = sigma / safe_x

    # erase values under switch
    sigma_over_x = torch.where(
        below_switch,
        torch.zeros_like(sigma_over_x),
        sigma_over_x,
    )

    return epsilon * (
        coefficients[0] * sigma_over_x ** order[0]
        - coefficients[1] * sigma_over_x ** order[1]
    )
def gaussian(x, coefficients, phases=[idx * 0.001 for idx in range(200)]):
    r"""Gaussian basis function.

    Evaluates ``sum_i coefficients[..., i] * exp(-0.5 * (x - phases[i]) ** 2)``
    for every entry of `x`.

    Parameters
    ----------
    x : torch.Tensor, two-dimensional
        (the source comments call the axes hypernodes and snapshots)
    coefficients : torch.Tensor, indexed as (rows of x, n_phases)
    phases : list or 1-D torch.Tensor of length n_phases
        centres of the Gaussians

    Returns
    -------
    torch.Tensor with one value per entry of `x`
    """
    if isinstance(phases, list):
        # (number_of_phases, )
        phases = torch.tensor(phases, device=x.device)

    n_rows, n_cols = x.shape[0], x.shape[1]

    # Tile every operand up to (n_rows, n_cols, number_of_phases).
    centers = phases[None, None, :].repeat(n_rows, n_cols, 1)
    x_tiled = x[:, :, None].repeat(1, 1, centers.shape[-1])
    weights = coefficients[:, None, :].repeat(1, n_cols, 1)

    basis = torch.exp(-0.5 * (x_tiled - centers) ** 2)
    return (weights * basis).sum(-1)
def linear_mixture(x, coefficients, phases=[0.0, 1.0]):
    r"""Linear mixture basis function.

    Returns ``0.5 * (k1 * (x - b1) ** 2 + k2 * (x - b2) ** 2)`` where
    ``k1, k2`` are the two columns of `coefficients` and ``b1, b2`` are the
    two entries of `phases`.

    Parameters
    ----------
    x : torch.Tensor
    coefficients : torch.Tensor whose last axis has length 2
    phases : sequence of length 2

    Raises
    ------
    AssertionError
        if `phases` or the last axis of `coefficients` is not of length 2
    """
    assert len(phases) == 2, "Only two phases now."
    assert coefficients.shape[-1] == 2

    # the two fixed centres
    center_a = phases[0]
    center_b = phases[1]

    # one column of weights per centre, kept as (batch_size, 1)
    weight_a = coefficients[:, 0].unsqueeze(1)
    weight_b = coefficients[:, 1].unsqueeze(1)

    # The equivalent single-harmonic parameters could be recovered with
    # linear_mixture_to_original(k1, k2, b1, b2); the constant offset between
    # the two forms is deliberately not subtracted here.
    energy_a = weight_a * (x - center_a) ** 2
    energy_b = weight_b * (x - center_b) ** 2

    return 0.5 * (energy_a + energy_b)
def harmonic_periodic_coupled(
    x_harmonic,
    x_periodic,
    k,
    eq,
    periodicity=list(range(1, 3)),
):
    """Coupling between a harmonic coordinate and a periodic coordinate.

    Returns ``(x_harmonic - eq) * sum_n k_n * cos(n * x_periodic)`` with
    ``n`` running over `periodicity`.

    Parameters
    ----------
    x_harmonic : `torch.Tensor`, broadcastable against `x_periodic`
    x_periodic : `torch.Tensor`, two-dimensional
    k : `torch.Tensor`, indexed as (rows of x_periodic, len(periodicity))
    eq : `torch.Tensor`, broadcastable against `x_harmonic`
    periodicity : list or 1-D `torch.Tensor`

    Returns
    -------
    energy : `torch.Tensor`
    """
    if isinstance(periodicity, list):
        periodicity = torch.tensor(
            periodicity,
            device=x_harmonic.device,
            dtype=torch.get_default_dtype(),
        )

    n_rows, n_cols = x_periodic.shape[0], x_periodic.shape[1]

    # (n_rows, n_cols, n_periodicities)
    multiples = periodicity[None, None, :].repeat(n_rows, n_cols, 1)
    cosines = (multiples * x_periodic[:, :, None]).cos()

    k = k[:, None, :].repeat(1, n_cols, 1)
    periodic_part = (k * cosines).sum(dim=-1)

    harmonic_part = x_harmonic - eq
    return harmonic_part * periodic_part
def harmonic_harmonic_coupled(
    x0,
    x1,
    eq0,
    eq1,
    k,
):
    """Bilinear coupling between two harmonic coordinates.

    Returns ``k * (x0 - eq0) * (x1 - eq1)``; all arguments only need to
    support arithmetic and broadcast against each other.
    """
    displacement0 = x0 - eq0
    displacement1 = x1 - eq1
    return k * displacement0 * displacement1
def harmonic_harmonic_periodic_coupled(
    theta0,
    theta1,
    eq0,
    eq1,
    phi,
    k,
):
    """Coupling of two harmonic coordinates modulated by a periodic one.

    Returns ``k * (theta0 - eq0) * (theta1 - eq1) * cos(phi)``. `phi` must
    provide a ``.cos()`` method (e.g. a `torch.Tensor`).
    """
    displacement0 = theta0 - eq0
    displacement1 = theta1 - eq1
    modulation = phi.cos()
    return k * displacement0 * displacement1 * modulation