Source code for espaloma.data.utils

# =============================================================================
# IMPORTS
# =============================================================================
import random

import numpy as np
import pandas as pd
import torch
import contextlib

import espaloma as esp

# Per-element offsets looked up and summed by ``sum_offsets``.
# NOTE(review): the keys look like atomic numbers (1=H, 6=C, 7=N, 8=O) and
# the values like per-atom reference energies; neither the meaning nor the
# units are stated in this file -- TODO confirm against the data source.
OFFSETS = {
    1: -0.500607632585,
    6: -37.8302333826,
    7: -54.5680045287,
    8: -75.0362229210,
}

# ==============================================================================
# UTILITY FUNCTIONS
# ==============================================================================


@contextlib.contextmanager
def make_temp_directory():
    """Create a temporary directory and delete it when the context exits.

    Yields
    ------
    str
        Path of the freshly created directory. The directory and everything
        inside it are removed on exit, including when the body raises.
    """
    import shutil
    import tempfile

    scratch_path = tempfile.mkdtemp()
    try:
        yield scratch_path
    finally:
        # Cleanup lives in ``finally`` so it also runs when the body raises.
        shutil.rmtree(scratch_path)
def sum_offsets(elements):
    """Return the sum of the ``OFFSETS`` entries for the given elements.

    Each item of ``elements`` is used as a key into the module-level
    ``OFFSETS`` table; a key missing from the table raises ``KeyError``.
    An empty iterable sums to ``0``.
    """
    total = 0
    for element in elements:
        total = total + OFFSETS[element]
    return total
def from_csv(path, toolkit="rdkit", smiles_col=-1, y_cols=None, seed=2666):
    """Build a loader that reads a CSV file into a shuffled dataset.

    The file is not read here: a zero-argument callable is returned, and the
    CSV is parsed each time that callable is invoked.

    Parameters
    ----------
    path : str
        Location of the CSV file, passed to ``pandas.read_csv``.
    toolkit : str
        Toolkit used to parse the SMILES strings: ``"rdkit"`` or
        ``"openeye"``.
    smiles_col : int
        Positional index of the column holding the SMILES strings.
    y_cols : list of int, optional
        Positional indices of the target columns. Defaults to ``[-2]``.
    seed : int
        Seed handed to ``random.seed`` before shuffling. Note that this
        reseeds the global ``random`` generator every time the loader runs.

    Returns
    -------
    callable
        Zero-argument function returning a shuffled list of
        ``(esp.HomogeneousGraph, torch.Tensor)`` pairs, one per CSV row.

    Raises
    ------
    ValueError
        Immediately, if ``toolkit`` is not supported; or from the loader, if
        OpenEye cannot parse one of the SMILES strings.
    """
    supported_toolkits = ("rdkit", "openeye")
    if toolkit not in supported_toolkits:
        # Fail fast: otherwise the loader would die later on an unbound
        # local with no hint about the actual cause.
        raise ValueError(
            "Unsupported toolkit %r; expected one of %s."
            % (toolkit, ", ".join(supported_toolkits))
        )

    # ``None`` sentinel instead of a mutable list default.
    if y_cols is None:
        y_cols = [-2]

    def _from_csv():
        df = pd.read_csv(path)
        df_smiles = df.iloc[:, smiles_col]
        df_y = df.iloc[:, y_cols]

        if toolkit == "rdkit":
            from rdkit import Chem

            mols = [Chem.MolFromSmiles(smiles) for smiles in df_smiles]
        else:
            from openeye import oechem

            mols = []
            for smiles in df_smiles:
                # OESmilesToMol fills ``mol`` in place and returns a success
                # flag, so the molecule object itself must be kept.
                mol = oechem.OEGraphMol()
                if not oechem.OESmilesToMol(mol, smiles):
                    raise ValueError(
                        "OpenEye could not parse SMILES %r." % (smiles,)
                    )
                mols.append(mol)

        gs = [esp.HomogeneousGraph(mol) for mol in mols]
        ds = list(zip(gs, list(torch.tensor(df_y.values))))

        random.seed(seed)
        random.shuffle(ds)

        return ds

    return _from_csv
def normalize(ds):
    """Compute target statistics and matching (un)normalization functions.

    Parameters
    ----------
    ds : iterable of pairs
        Dataset whose items are two-element sequences; the second element of
        each item is treated as the target value.

    Returns
    -------
    y_mean
        ``np.mean`` over all targets.
    y_std
        ``np.std`` over all targets.
    norm : callable
        Maps ``y`` to ``(y - y_mean) / y_std``.
    unnorm : callable
        Maps ``y`` to ``y * y_std + y_mean``.
    """
    # Only the targets are needed; the first field of each pair is ignored.
    _, targets = tuple(zip(*ds))

    y_mean = np.mean(targets)
    y_std = np.std(targets)

    def norm(y):
        centered = y - y_mean
        return centered / y_std

    def unnorm(y):
        scaled = y * y_std
        return scaled + y_mean

    return y_mean, y_std, norm, unnorm
def split(ds, partition):
    """Cut a dataset into consecutive chunks with the given relative sizes.

    Parameters
    ----------
    ds : sequence
        Dataset to slice.
    partition : sequence of numbers
        Relative weight of each chunk; weights are divided by their sum.

    Returns
    -------
    list
        One slice of ``ds`` per weight, in order. Chunk sizes are truncated
        to integers, so trailing items may be left out of every chunk.
    """
    total = len(ds)

    # Turn the relative weights into absolute chunk sizes (rounded down).
    sizes = [int(total * weight / sum(partition)) for weight in partition]

    chunks = []
    start = 0
    for size in sizes:
        stop = start + size
        chunks.append(ds[start:stop])
        start = stop

    return chunks
def batch(ds, batch_size, seed=2666):
    """Shuffle a dataset and group it into fixed-size batches.

    Parameters
    ----------
    ds : list
        List of ``(graph, value)`` pairs. It is shuffled in place.
    batch_size : int
        Number of pairs per batch.
    seed : int
        Seed handed to ``random.seed`` before shuffling. Note that this
        reseeds the global ``random`` generator.

    Returns
    -------
    list
        ``(batched_graph, stacked_values)`` pairs, where graphs are merged
        with ``dgl.batch`` and values with ``torch.stack`` along dim 0.
        Only full batches are produced; leftover items are dropped. An
        empty list is returned when ``ds`` holds fewer than ``batch_size``
        items.
    """
    # Number of full batches; the remainder is dropped.
    n_batches = len(ds) // batch_size

    random.seed(seed)
    random.shuffle(ds)

    # Without this guard an empty dataset fails on the tuple unpacking below.
    if n_batches <= 0:
        return []

    import dgl

    gs, ys = tuple(zip(*ds))

    starts = range(0, n_batches * batch_size, batch_size)
    gs_batched = [dgl.batch(gs[start : start + batch_size]) for start in starts]
    ys_batched = [
        torch.stack(ys[start : start + batch_size], dim=0) for start in starts
    ]

    return list(zip(gs_batched, ys_batched))
def collate_fn(graphs):
    """Merge graphs into one batch and wrap it as a homogeneous graph.

    Parameters
    ----------
    graphs
        Collection of graphs, passed unchanged to ``dgl.batch``.

    Returns
    -------
    esp.HomogeneousGraph
        Wrapper constructed from the batched DGL graph.
    """
    import dgl

    merged = dgl.batch(graphs)
    return esp.HomogeneousGraph(merged)
def infer_mol_from_coordinates(
    coordinates,
    species,
    smiles_ref=None,
    coordinates_unit="angstrom",
):
    """Build a graph for a molecule given only its atoms and coordinates.

    Atoms are created from ``species``, the coordinates are attached, and
    OpenEye perceives connectivity, rings and bond orders from the geometry.

    Parameters
    ----------
    coordinates : array-like
        Atomic coordinates; must support ``reshape`` after unit conversion.
        # NOTE(review): presumably shaped (n_atoms, 3) -- confirm with callers.
    species : sequence of str or sequence of int
        Either all element symbols or all atomic numbers (NumPy integers
        are accepted too).
    smiles_ref : str, optional
        Reference SMILES. When given, its canonical form must match the
        canonical SMILES of the perceived molecule.
    coordinates_unit : str or simtk unit
        Unit of ``coordinates``; a string is resolved as an attribute of
        ``simtk.unit``.

    Returns
    -------
    esp.Graph
        Graph built from the OpenFF ``Molecule`` converted from the
        perceived OpenEye molecule (undefined stereo is allowed).

    Raises
    ------
    RuntimeError
        If ``species`` is neither all strings nor all integers.
    AssertionError
        If ``smiles_ref`` is given and does not match the perceived molecule.
    """
    # local import
    import numbers

    from openeye import oechem
    from simtk import unit
    from simtk.unit import Quantity

    if isinstance(coordinates_unit, str):
        coordinates_unit = getattr(unit, coordinates_unit)

    # make sure we have the coordinates
    # in the unit system
    coordinates = Quantity(coordinates, coordinates_unit).value_in_unit(
        unit.angstrom  # to make openeye happy
    )

    # initialize molecule
    mol = oechem.OEGraphMol()

    if all(isinstance(symbol, str) for symbol in species):
        symbols = list(species)
    elif all(isinstance(symbol, numbers.Integral) for symbol in species):
        # ``numbers.Integral`` also matches NumPy integer scalars; cast to a
        # plain ``int`` before handing the value to OpenEye.
        symbols = [
            oechem.OEGetAtomicSymbol(int(atomic_number))
            for atomic_number in species
        ]
    else:
        raise RuntimeError(
            "The species can only be all strings or all integers."
        )

    for symbol in symbols:
        mol.NewAtom(getattr(oechem, "OEElemNo_" + symbol))

    mol.SetCoords(coordinates.reshape([-1]))
    mol.SetDimension(3)
    oechem.OEDetermineConnectivity(mol)
    oechem.OEFindRingAtomsAndBonds(mol)
    oechem.OEPerceiveBondOrders(mol)

    if smiles_ref is not None:
        smiles_can = oechem.OECreateCanSmiString(mol)
        ims = oechem.oemolistream()
        ims.SetFormat(oechem.OEFormat_SMI)
        ims.openstring(smiles_ref)
        mol_ref = next(ims.GetOEMols())
        smiles_ref = oechem.OECreateCanSmiString(mol_ref)
        if smiles_ref != smiles_can:
            # Raised explicitly (rather than via ``assert``) so the check
            # survives ``python -O``; the exception type is unchanged.
            raise AssertionError(
                "SMILES different. Input is %s, ref is %s"
                % (
                    smiles_can,
                    smiles_ref,
                )
            )

    from openff.toolkit.topology import Molecule

    _mol = Molecule.from_openeye(mol, allow_undefined_stereo=True)
    g = esp.Graph(_mol)

    return g