Source code for espaloma.graphs.legacy_force_field

# =============================================================================
# IMPORTS
# =============================================================================
import rdkit
import torch
from openff.toolkit.topology import Molecule
import espaloma as esp

from openmmforcefields.generators import SystemGenerator
from simtk import openmm, unit
from simtk.openmm.app import Simulation
from simtk.unit import Quantity

# =============================================================================
# CONSTANTS
# =============================================================================
# Atom types treated as duplicates of another type (key -> replacement).
# `LegacyForceField._prepare_gaff` removes every key from the list of atom
# types read from the GAFF ffxml file and then gives the key the same integer
# index as its replacement.
REDUNDANT_TYPES = {
    "cd": "cc",
    "cf": "ce",
    "cq": "cp",
    "pd": "pc",
    "pf": "pe",
    "nd": "nc",
}

# simulation specs
# TEMPERATURE, COLLISION_RATE and STEP_SIZE are the arguments handed to
# `openmm.LangevinIntegrator` in `LegacyForceField.baseline_energy`.
TEMPERATURE = 350 * unit.kelvin
STEP_SIZE = 1.0 * unit.femtosecond
COLLISION_RATE = 1.0 / unit.picosecond
# NOTE(review): EPSILON_MIN is not referenced anywhere in the visible part of
# this module -- presumably consumed elsewhere; confirm before removing.
EPSILON_MIN = 0.05 * unit.kilojoules_per_mole

# =============================================================================
# MODULE CLASSES
# =============================================================================
class LegacyForceField:
    """Class to hold legacy forcefield for typing and parameter assignment.

    Parameters
    ----------
    forcefield : string
        name and version of the forcefield.  The name must contain one of
        ``"gaff"``, ``"smirnoff"`` or ``"openff"``; anything else raises
        ``NotImplementedError`` at construction time.

    Methods
    -------
    parametrize()
        Parametrize a molecular system.

    typing()
        Provide legacy typing for a molecular system.

    multi_typing()
        Provide legacy typing for atom, bond and angle nodes.

    baseline_energy()
        Evaluate the force field energy of the conformations stored on a
        graph.

    """

    def __init__(self, forcefield="gaff-1.81"):
        self.forcefield = forcefield
        self._prepare_forcefield()

    # ------------------------------------------------------------------
    # set-up
    # ------------------------------------------------------------------
    @staticmethod
    def _convert_to_off(mol):
        """Convert a supported molecule container to an OpenFF ``Molecule``.

        Handles ``esp.Graph`` (returns its ``mol`` attribute), OpenFF
        ``Molecule`` (returned unchanged), RDKit ``Mol`` and any object whose
        type name contains ``"openeye"``.

        Returns
        -------
        The converted molecule, or ``None`` when ``mol`` is of none of the
        types listed above.
        """
        if isinstance(mol, esp.Graph):
            return mol.mol

        if isinstance(mol, Molecule):
            return mol

        if isinstance(mol, rdkit.Chem.rdchem.Mol):
            return Molecule.from_rdkit(mol)

        # matched by type name because we don't want to depend on OE
        if "openeye" in str(type(mol)):
            return Molecule.from_openeye(mol)

        # unsupported input: same result as before, now spelled out
        return None

    def _prepare_forcefield(self):
        """Dispatch to the loader matching ``self.forcefield``.

        Raises
        ------
        NotImplementedError
            If the name contains none of ``"gaff"``, ``"smirnoff"``,
            ``"openff"``.
        """
        name = self.forcefield

        if "gaff" in name:
            self._prepare_gaff()
        elif "smirnoff" in name:
            self._prepare_smirnoff()
        elif "openff" in name:
            self._prepare_openff()
        else:
            raise NotImplementedError(
                "unsupported force field: %s" % (name,)
            )

    def _load_offxml(self):
        """Load ``<forcefield>.offxml`` into ``self.FF``."""
        from openff.toolkit.typing.engines.smirnoff import ForceField

        self.FF = ForceField("%s.offxml" % self.forcefield)

    def _prepare_openff(self):
        """Load an OpenFF force field (same loader as the smirnoff family)."""
        self._load_offxml()

    def _prepare_smirnoff(self):
        """Load a smirnoff force field from its offxml file."""
        self._load_offxml()

    def _prepare_gaff(self):
        """Build the atom-type <-> integer translation tables for GAFF.

        Reads the ffxml file shipped inside the installed
        ``openmmforcefields`` package, collects the ``type`` attribute of the
        ``Atom`` entries of the last section of the file, drops the keys of
        ``REDUNDANT_TYPES`` and stores the resulting tables as
        ``self._str_2_idx`` and ``self._idx_2_str``.
        """
        import os
        import xml.etree.ElementTree as ET

        import openmmforcefields

        # locate the ffxml file relative to the openmmforcefields package
        openmmforcefields_path = os.path.dirname(openmmforcefields.__file__)
        ffxml_path = os.path.join(
            openmmforcefields_path,
            "ffxml",
            "amber",
            "gaff",
            "ffxml",
            self.forcefield + ".xml",
        )

        # parse xml; the last child of the root is taken to be the
        # nonbonded section -- NOTE(review): relies on the file layout
        root = ET.parse(ffxml_path).getroot()
        nonbonded = list(root)[-1]
        atom_types = [atom.get("type") for atom in nonbonded.findall("Atom")]

        # remove redundant types (raises ValueError if one is missing)
        for bad_type in REDUNDANT_TYPES:
            atom_types.remove(bad_type)

        # compose the translation dictionaries
        str_2_idx = {name: idx for idx, name in enumerate(atom_types)}
        idx_2_str = dict(enumerate(atom_types))

        # redundant types share the index of the type they map to
        for bad_type, good_type in REDUNDANT_TYPES.items():
            str_2_idx[bad_type] = str_2_idx[good_type]

        # make translation dictionaries attributes of self
        self._str_2_idx = str_2_idx
        self._idx_2_str = idx_2_str

    def _create_system(self, mol):
        """Create the OpenMM topology and system of ``mol``.

        Returns
        -------
        tuple
            ``(topology, system)`` where ``system`` is produced by a
            ``SystemGenerator`` using ``self.forcefield`` as the small
            molecule force field.
        """
        from openmmforcefields.generators import SystemGenerator

        system_generator = SystemGenerator(
            small_molecule_forcefield=self.forcefield,
        )
        topology = mol.to_topology().to_openmm()
        system = system_generator.create_system(
            topology=topology,
            molecules=mol,
        )
        return topology, system

    # ------------------------------------------------------------------
    # gaff
    # ------------------------------------------------------------------
    def _type_gaff(self, g):
        """Type a molecular graph using gaff force fields.

        Runs antechamber through ``GAFFTemplateGenerator`` on an SDF dump of
        ``g.mol`` and stores the integer-encoded atom types in
        ``g.nodes["n1"].data["legacy_typing"]``.

        Returns
        -------
        The same graph ``g``.
        """
        # assert the forcefield is indeed of gaff family
        assert "gaff" in self.forcefield

        import os
        import tempfile

        from openmmforcefields.generators import GAFFTemplateGenerator

        mol = g.mol
        gaff = GAFFTemplateGenerator(
            molecules=mol, forcefield=self.forcefield
        )

        # The scratch directory is removed on exit from the ``with`` block,
        # including when antechamber fails (it used to be leaked then).
        with tempfile.TemporaryDirectory() as tempdir:
            prefix = "molecule"
            input_sdf_filename = os.path.join(tempdir, prefix + ".sdf")
            gaff_mol2_filename = os.path.join(tempdir, prefix + ".gaff.mol2")
            frcmod_filename = os.path.join(tempdir, prefix + ".frcmod")

            # write sdf for input
            mol.to_file(input_sdf_filename, file_format="sdf")

            # run antechamber
            gaff._run_antechamber(
                molecule_filename=input_sdf_filename,
                input_format="mdl",
                gaff_mol2_filename=gaff_mol2_filename,
                frcmod_filename=frcmod_filename,
            )

            # attaches a ``gaff_type`` attribute to every atom of ``mol``
            gaff._read_gaff_atom_types_from_mol2(gaff_mol2_filename, mol)

        gaff_types = [atom.gaff_type for atom in mol.atoms]

        # put types into graph object
        g.nodes["n1"].data["legacy_typing"] = torch.tensor(
            [self._str_2_idx[atom] for atom in gaff_types]
        )

        return g

    def _parametrize_gaff(self, g, n_max_phases=6):
        """Copy GAFF bond, angle and proper-torsion parameters onto ``g``.

        Parameters
        ----------
        g : graph object exposing ``mol``, ``nodes`` and ``heterograph``.
        n_max_phases : int
            Number of torsion term slots per ``n4`` node.

        Returns
        -------
        The same graph with ``eq_ref``/``k_ref`` on ``n2`` and ``n3`` nodes
        and ``k_ref``/``periodicity_ref``/``phases_ref`` on ``n4`` nodes.
        Improper torsions are not populated here.
        """
        _, system = self._create_system(g.mol)
        hg = g.heterograph

        def _positions(ntype):
            # atom-index tuple -> row of the node in the graph
            return {
                tuple(idxs.detach().numpy()): position
                for position, idxs in enumerate(g.nodes[ntype].data["idxs"])
            }

        bond_lookup = _positions("n2")
        angle_lookup = _positions("n3")
        torsion_lookup = _positions("n4")

        n_torsions = hg.number_of_nodes("n4")
        torsion_phases = torch.zeros(n_torsions, n_max_phases)
        torsion_periodicities = torch.zeros(n_torsions, n_max_phases)
        torsion_ks = torch.zeros(n_torsions, n_max_phases)

        for force in system.getForces():
            name = force.__class__.__name__

            if "HarmonicBondForce" in name:
                # every bond appears twice in the graph, once per direction
                n_rows = force.getNumBonds() * 2
                assert n_rows == hg.number_of_nodes("n2")

                eq_ref = torch.zeros(n_rows, 1)
                k_ref = torch.zeros(n_rows, 1)

                for idx in range(force.getNumBonds()):
                    idx0, idx1, eq, k = force.getBondParameters(idx)
                    eq_value = eq.value_in_unit(esp.units.DISTANCE_UNIT)
                    k_value = k.value_in_unit(esp.units.FORCE_CONSTANT_UNIT)

                    for key in ((idx0, idx1), (idx1, idx0)):
                        position = bond_lookup[key]
                        eq_ref[position] = eq_value
                        k_ref[position] = k_value

                g.nodes["n2"].data["eq_ref"] = eq_ref
                g.nodes["n2"].data["k_ref"] = k_ref

            elif "HarmonicAngleForce" in name:
                # every angle appears twice in the graph, once per direction
                n_rows = force.getNumAngles() * 2
                assert n_rows == hg.number_of_nodes("n3")

                eq_ref = torch.zeros(n_rows, 1)
                k_ref = torch.zeros(n_rows, 1)

                for idx in range(force.getNumAngles()):
                    idx0, idx1, idx2, eq, k = force.getAngleParameters(idx)
                    eq_value = eq.value_in_unit(esp.units.ANGLE_UNIT)
                    k_value = k.value_in_unit(
                        esp.units.ANGLE_FORCE_CONSTANT_UNIT
                    )

                    for key in ((idx0, idx1, idx2), (idx2, idx1, idx0)):
                        position = angle_lookup[key]
                        eq_ref[position] = eq_value
                        k_ref[position] = k_value

                g.nodes["n3"].data["eq_ref"] = eq_ref
                g.nodes["n3"].data["k_ref"] = k_ref

            elif "PeriodicTorsionForce" in name:
                for idx in range(force.getNumTorsions()):
                    (
                        idx0,
                        idx1,
                        idx2,
                        idx3,
                        periodicity,
                        phase,
                        k,
                    ) = force.getTorsionParameters(idx)

                    forward = (idx0, idx1, idx2, idx3)
                    if forward not in torsion_lookup:
                        # not an ``n4`` node of the graph; skipped
                        continue

                    position = torsion_lookup[forward]
                    mirrored = torsion_lookup[(idx3, idx2, idx1, idx0)]
                    k_value = 0.5 * k.value_in_unit(esp.units.ENERGY_UNIT)
                    phase_value = phase.value_in_unit(esp.units.ANGLE_UNIT)

                    # use the first slot whose force constant is still zero;
                    # the term is dropped when all slots are taken
                    for sub_idx in range(n_max_phases):
                        if torsion_ks[position, sub_idx] == 0:
                            for row in (position, mirrored):
                                torsion_ks[row, sub_idx] = k_value
                                torsion_phases[row, sub_idx] = phase_value
                                torsion_periodicities[
                                    row, sub_idx
                                ] = periodicity
                            break

        hg.apply_nodes(
            lambda nodes: {
                "k_ref": torsion_ks,
                "periodicity_ref": torsion_periodicities,
                "phases_ref": torsion_phases,
            },
            ntype="n4",
        )

        return g

    # ------------------------------------------------------------------
    # smirnoff / openff
    # ------------------------------------------------------------------
    def _parametrize_smirnoff(self, g):
        """Copy the parameters labelled by ``self.FF`` onto ``g``.

        Writes ``k_ref``/``eq_ref`` on ``n2`` and ``n3`` nodes,
        ``epsilon_ref``/``sigma_ref`` on ``n1`` nodes and
        ``k_ref``/``periodicity_ref``/``phases_ref`` on ``n4`` and
        ``n4_improper`` nodes, then returns the same graph.
        """
        forces = self.FF.label_molecules(g.mol.to_topology())[0]
        hg = g.heterograph

        def _column(values):
            # list of scalars -> (n, 1) float tensor
            return torch.Tensor(values)[:, None]

        def _labels(nodes, term):
            # parameter object of every node, looked up by its atom indices
            table = forces[term]
            return [table[tuple(row.numpy())] for row in nodes.data["idxs"]]

        def apply_bond(nodes):
            params = _labels(nodes, "Bonds")
            return {
                # force constant is doubled, as for the angles below
                "k_ref": 2.0
                * _column(
                    [
                        param.k.value_in_unit(esp.units.FORCE_CONSTANT_UNIT)
                        for param in params
                    ]
                ),
                "eq_ref": _column(
                    [
                        param.length.value_in_unit(esp.units.DISTANCE_UNIT)
                        for param in params
                    ]
                ),
            }

        def apply_angle(nodes):
            params = _labels(nodes, "Angles")
            return {
                # OpenFF records 1/2k as param
                "k_ref": 2.0
                * _column(
                    [
                        param.k.value_in_unit(
                            esp.units.ANGLE_FORCE_CONSTANT_UNIT
                        )
                        for param in params
                    ]
                ),
                "eq_ref": _column(
                    [
                        param.angle.value_in_unit(esp.units.ANGLE_UNIT)
                        for param in params
                    ]
                ),
            }

        def apply_vdw(nodes):
            params = [
                forces["vdW"][(idx,)]
                for idx in range(hg.number_of_nodes("n1"))
            ]
            return {
                "epsilon_ref": _column(
                    [
                        param.epsilon.value_in_unit(esp.units.ENERGY_UNIT)
                        for param in params
                    ]
                ),
                # NOTE(review): ``rmin_half`` is stored under the name
                # ``sigma_ref`` without conversion -- confirm this is what
                # consumers of ``sigma_ref`` expect.
                "sigma_ref": _column(
                    [
                        param.rmin_half.value_in_unit(
                            esp.units.DISTANCE_UNIT
                        )
                        for param in params
                    ]
                ),
            }

        def _torsion_applier(term, ntype, n_max_phases=6):
            # shared implementation for proper and improper torsions
            def apply(node):
                n_nodes = hg.number_of_nodes(ntype)
                phases = torch.zeros(n_nodes, n_max_phases)
                periodicity = torch.zeros(n_nodes, n_max_phases)
                k = torch.zeros(n_nodes, n_max_phases)

                force = forces[term]

                for idx in range(n_nodes):
                    idxs = tuple(node.data["idxs"][idx].numpy())
                    if idxs not in force:
                        continue

                    _force = force[idxs]

                    # NOTE(review): attributes are probed with 0-based
                    # suffixes (k0, phase0, ...) while OpenFF documents its
                    # indexed attributes as 1-based; kept as-is, confirm the
                    # intended term-to-column mapping.
                    for sub_idx in range(len(_force.periodicity)):
                        if hasattr(_force, "k%s" % sub_idx):
                            k[idx, sub_idx] = getattr(
                                _force, "k%s" % sub_idx
                            ).value_in_unit(esp.units.ENERGY_UNIT)

                            phases[idx, sub_idx] = getattr(
                                _force, "phase%s" % sub_idx
                            ).value_in_unit(esp.units.ANGLE_UNIT)

                            periodicity[idx, sub_idx] = getattr(
                                _force, "periodicity%s" % sub_idx
                            )

                return {
                    "k_ref": k,
                    "periodicity_ref": periodicity,
                    "phases_ref": phases,
                }

            return apply

        hg.apply_nodes(apply_bond, ntype="n2")
        hg.apply_nodes(apply_angle, ntype="n3")
        hg.apply_nodes(apply_vdw, ntype="n1")
        hg.apply_nodes(
            _torsion_applier("ProperTorsions", "n4"), ntype="n4"
        )
        hg.apply_nodes(
            _torsion_applier("ImproperTorsions", "n4_improper"),
            ntype="n4_improper",
        )

        return g

    def baseline_energy(self, g, suffix=None):
        """Evaluate the force field energy of the conformations on ``g``.

        Parameters
        ----------
        g : graph object whose ``nodes["n1"].data["xyz"]`` holds the
            coordinates; snapshots are iterated along the second axis
            (presumably ``(n_atoms, n_snapshots, 3)`` -- TODO confirm).
        suffix : str, optional
            Appended to ``"u"`` to form the key of the result; defaults to
            ``"_" + self.forcefield``.

        Returns
        -------
        The same graph with the energies stored, with a leading axis of
        length one, in ``g.nodes["g"].data["u" + suffix]``.
        """
        if suffix is None:
            suffix = "_" + self.forcefield

        # parameterize topology
        topology, system = self._create_system(g.mol)

        integrator = openmm.LangevinIntegrator(
            TEMPERATURE, COLLISION_RATE, STEP_SIZE
        )

        # create simulation
        simulation = Simulation(
            topology=topology, system=system, integrator=integrator
        )

        # coordinates in nanometer, one leading entry per snapshot
        xs = (
            Quantity(
                g.nodes["n1"].data["xyz"].detach().numpy(),
                esp.units.DISTANCE_UNIT,
            )
            .value_in_unit(unit.nanometer)
            .transpose((1, 0, 2))
        )

        us = []
        for x in xs:
            simulation.context.setPositions(x)
            us.append(
                simulation.context.getState(getEnergy=True)
                .getPotentialEnergy()
                .value_in_unit(esp.units.ENERGY_UNIT)
            )

        g.nodes["g"].data["u%s" % suffix] = torch.tensor(us)[None, :]

        return g

    def _multi_typing_smirnoff(self, g):
        """Store the integer part of the parameter ids as ``legacy_typing``.

        For ``n2`` (bonds), ``n3`` (angles) and ``n1`` (vdW) nodes, the id of
        the matched parameter is stripped of its first character and
        converted to an integer.
        """
        forces = self.FF.label_molecules(g.mol.to_topology())[0]
        hg = g.heterograph

        def _type_ids(params):
            # e.g. an id such as "b12" becomes 12
            return torch.Tensor(
                [int(param.id[1:]) for param in params]
            ).long()

        def _bonded_applier(term):
            def apply(node):
                table = forces[term]
                return {
                    "legacy_typing": _type_ids(
                        [
                            table[tuple(row.numpy())]
                            for row in node.data["idxs"]
                        ]
                    )
                }

            return apply

        def apply_vdw(node):
            return {
                "legacy_typing": _type_ids(
                    [
                        forces["vdW"][(idx,)]
                        for idx in range(hg.number_of_nodes("n1"))
                    ]
                )
            }

        hg.apply_nodes(_bonded_applier("Bonds"), ntype="n2")
        hg.apply_nodes(_bonded_applier("Angles"), ntype="n3")
        hg.apply_nodes(apply_vdw, ntype="n1")

        return g

    # ------------------------------------------------------------------
    # public API
    # ------------------------------------------------------------------
    def parametrize(self, g):
        """Parametrize a molecular graph."""
        if "smirnoff" in self.forcefield or "openff" in self.forcefield:
            return self._parametrize_smirnoff(g)

        if "gaff" in self.forcefield:
            return self._parametrize_gaff(g)

        raise NotImplementedError(
            "unsupported force field: %s" % (self.forcefield,)
        )

    def typing(self, g):
        """Type a molecular graph."""
        if "gaff" in self.forcefield:
            return self._type_gaff(g)

        raise NotImplementedError(
            "typing is only available for gaff, not %s" % (self.forcefield,)
        )

    def multi_typing(self, g):
        """Type a molecular graph for hetero nodes."""
        if "smirnoff" in self.forcefield:
            return self._multi_typing_smirnoff(g)

        raise NotImplementedError(
            "multi_typing is only available for smirnoff, not %s"
            % (self.forcefield,)
        )

    def __call__(self, *args, **kwargs):
        """Shorthand for :meth:`typing`."""
        return self.typing(*args, **kwargs)