Toy experiment: Molecular mechanics (MM) fitting on a subsampled PhAlkEthOH dataset.

Open in Google Colab: http://data.wangyq.net/esp_notebooks/phalkethoh_mm_small.ipynb

This notebook is intended to reproduce the MM fitting behavior reported in https://arxiv.org/abs/2010.01196

To assess how well Espaloma can learn to reproduce an MM force field from a limited amount of data, we selected a chemical dataset of limited complexity—PhAlkEthOH—which consists of linear and cyclic molecules containing phenyl rings, small alkanes, ethers, and alcohols composed of only the elements carbon, oxygen, and hydrogen. We generated a set of conformational snapshots for each molecule using short high-temperature molecular dynamics simulations at 300 K initiated from multiple conformations to ensure adequate sampling of conformers. The PhAlkEthOH dataset was randomly partitioned (by molecules) into 80% training, 10% validation, and 10% test molecules, with 100 snapshots/molecule, and an Espaloma model was trained with early stopping via monitoring for a decrease in accuracy in the validation set.

https://pbs.twimg.com/media/FBL0qACXIBkJLQZ?format=png&name=4096x4096

Installation and imports

# Install conda into the Colab runtime via condacolab.
# NOTE: condacolab.install() restarts the kernel (its output reports
# "Restarting kernel..."), so run this cell by itself before the others.
! pip install -q condacolab
import condacolab
condacolab.install()
⏬ Downloading https://github.com/jaimergp/miniforge/releases/latest/download/Mambaforge-colab-Linux-x86_64.sh...
📦 Installing...
📌 Adjusting configuration...
🩹 Patching environment...
⏲ Done in 0:00:34
🔁 Restarting kernel...
%%capture
# Install OpenMM / OpenFF / RDKit / QCPortal and the CUDA 10.0 build of DGL.
# %%capture must remain the first line of the cell; it hides the solver output.
# NOTE(review): "--channel numpy" looks like it may be a typo (numpy meant as a
# package rather than a channel?) -- confirm the intended channel list.
! mamba install --yes --strict-channel-priority --channel jaimergp/label/unsupported-cudatoolkit-shim --channel omnia --channel omnia/label/cuda100 --channel dglteam --channel numpy openmm openmmtools openmmforcefields rdkit openff-toolkit dgl-cuda10.0 qcportal
# Fetch the espaloma sources; they are imported from /content/espaloma below
# instead of being installed as a package.
! git clone https://github.com/choderalab/espaloma.git
Cloning into 'espaloma'...
remote: Enumerating objects: 7812, done.
remote: Counting objects: 100% (3634/3634), done.
remote: Compressing objects: 100% (1649/1649), done.
remote: Total 7812 (delta 2714), reused 2639 (delta 1900), pack-reused 4178
Receiving objects: 100% (7812/7812), 13.50 MiB | 11.77 MiB/s, done.
Resolving deltas: 100% (5538/5538), done.
import torch
import sys
# Make the freshly cloned espaloma checkout importable without installing it.
sys.path.append("/content/espaloma")
import espaloma as esp
Warning: importing 'simtk.openmm' is deprecated.  Import 'openmm' instead.
Warning: Unable to load toolkit 'OpenEye Toolkit'. The Open Force Field Toolkit does not require the OpenEye Toolkits, and can use RDKit/AmberTools instead. However, if you have a valid license for the OpenEye Toolkits, consider installing them for faster performance and additional file format support: https://docs.eyesopen.com/toolkits/python/quickstart-python/linuxosx.html OpenEye offers free Toolkit licenses for academics: https://www.eyesopen.com/academic-licensing

Load dataset

Here we load the PhAlkEthOH dataset and shuffle it before splitting it into training, validation, and test sets (80%:10%:10%).

%%capture
# Download and unpack the dataset archive, then load it as a GraphDataset.
# (Assumes the archive unpacks to ./phalkethoh -- the path loaded below.)
! wget http://data.wangyq.net/esp_dataset/phalkethoh_mm_small.zip
! unzip phalkethoh_mm_small.zip
ds = esp.data.dataset.GraphDataset.load("phalkethoh")
# Fixed seed so the shuffle, and therefore the split, is repeatable.
ds.shuffle(seed=2666)
# Ratio 8:1:1 -> 80% train, 10% validation, 10% test.
ds_tr, ds_vl, ds_te = ds.split([8, 1, 1])
DGL backend not selected or invalid.  Assuming PyTorch for now.
Using backend: pytorch
Setting the default backend to "pytorch". You can change it in the ~/.dgl/config.json file or export the DGLBACKEND environment variable.  Valid options are: pytorch, mxnet, tensorflow (all lowercase)

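A training dataloader is constructed with batch_size=100; the full training and validation sets are also collated into single batched graphs for evaluation.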

# Shuffled mini-batches of 100 molecules used for the optimizer steps.
ds_tr_loader = ds_tr.view(batch_size=100, shuffle=True)


def _collate_all(dataset):
    """Return every graph in ``dataset`` batched together as a single graph."""
    whole_view = dataset.view(batch_size=len(dataset))
    return next(iter(whole_view))


# Entire training / validation sets as one batched graph each, for evaluation.
g_tr = _collate_all(ds_tr)
g_vl = _collate_all(ds_vl)
/usr/local/lib/python3.7/site-packages/dgl/base.py:45: DGLWarning: From v0.5, DGLHeteroGraph is merged into DGLGraph. You can safely replace dgl.batch_hetero with dgl.batch
  return warnings.warn(message, category=category, stacklevel=1)

Define model

Define Espaloma stage I: graph -> atom latent representation

# Stage I: three DGL SAGEConv layers of 128 units, each followed by ReLU.
_representation_config = [128, "relu"] * 3
representation = esp.nn.Sequential(
    layer=esp.nn.layers.dgl_legacy.gn("SAGEConv"),
    config=_representation_config,
)

Define Espaloma stages II and III: atom latent representations -> bond, angle, and torsion representations and their MM parameters. Then compose all three Espaloma stages into an end-to-end model.

# Stages II and III: Janossy pooling builds bond, angle, and torsion
# representations from the atom embeddings and maps each to MM parameters.
readout = esp.nn.readout.janossy.JanossyPooling(
    # in_features must equal the width of the last representation layer (128).
    in_features=128, config=[128, "relu", 128, "relu", 128, "relu"],
    # Modular MM parameters Espaloma will assign, keyed by how many atoms the
    # term involves (1 = atom, 2 = bond, 3 = angle, 4 = torsion).
    out_features={
        # Charge-equilibration inputs: e = electronegativity, s = hardness.
        1: {"e": 1, "s": 1},
        # Bond: log of 2 linear-combination coefficients; exponentiated
        # downstream (ExpCoefficients) so the coefficients stay positive.
        2: {"log_coefficients": 2},
        # Angle: same parameterization as bonds.
        3: {"log_coefficients": 2},
        # Torsion barrier heights (can be positive or negative).
        4: {"k": 6},
    },
)

# Chain stages I-III with the MM machinery into one end-to-end module.
# Module order (and therefore state_dict keys "0.*" ... "6.*") is significant.
_model_stages = [
    representation,  # stage I: graph -> atom latent representation
    readout,  # stages II/III: atom representation -> MM parameters
    esp.nn.readout.janossy.ExpCoefficients(),  # log-coefficients -> coefficients
    esp.mm.geometry.GeometryInGraph(),
    esp.mm.energy.EnergyInGraph(),  # presumably writes the predicted energy "u"
    esp.mm.energy.EnergyInGraph(suffix="_ref"),  # presumably the reference "u_ref"
    esp.nn.readout.charge_equilibrium.ChargeEquilibrium(),
]
espaloma_model = torch.nn.Sequential(*_model_stages)

# Run on the GPU whenever one is available.
espaloma_model = espaloma_model.cuda() if torch.cuda.is_available() else espaloma_model

The loss function is the mean squared error (MSE) between predicted and reference energies.

# Training objective: mean squared error between the predicted energy "u" and
# the reference MM energy "u_ref" (this notebook fits an MM force field, not QM).
loss_fn = esp.metrics.GraphMetric(
        base_metric=torch.nn.MSELoss(),  # mean-squared error loss
        between=["u", "u_ref"],          # predicted vs. reference (MM) energies
        level="g",  # compare on graph level
)

Define optimizer

# Adam over every model parameter, learning rate 1e-4.
optimizer = torch.optim.Adam(params=espaloma_model.parameters(), lr=1e-4)

Train it!

# Train for a fixed 10000 epochs, writing the weights to "<epoch>.th" after
# each epoch so per-epoch losses can be recomputed afterwards.
use_cuda = torch.cuda.is_available()  # loop-invariant; check once
for idx_epoch in range(10000):
    for g in ds_tr_loader:
        optimizer.zero_grad()
        if use_cuda:
            g = g.to("cuda:0")
        g = espaloma_model(g)
        loss = loss_fn(g)
        loss.backward()
        optimizer.step()
    # Checkpoint once per epoch. Previously this ran after every mini-batch,
    # rewriting the same file repeatedly; only the last write ever survived.
    torch.save(espaloma_model.state_dict(), f"{idx_epoch}.th")
/usr/local/lib/python3.7/site-packages/dgl/base.py:45: DGLWarning: From v0.5, DGLHeteroGraph is merged into DGLGraph. You can safely replace dgl.batch_hetero with dgl.batch
  return warnings.warn(message, category=category, stacklevel=1)
/usr/local/lib/python3.7/site-packages/dgl/base.py:45: DGLWarning: dgl.to_homo is deprecated. Please use dgl.to_homogeneous
  return warnings.warn(message, category=category, stacklevel=1)

Inspect

# Evaluation metric: mean absolute error between predicted and reference (MM)
# energies, compared on the graph level.
inspect_metric = esp.metrics.GraphMetric(
        base_metric=torch.nn.L1Loss(),  # mean absolute error (L1), not MSE
        between=["u", "u_ref"],         # predicted vs. reference (MM) energies
        level="g",  # compare on graph level
)
if torch.cuda.is_available():
    g_vl = g_vl.to("cuda:0")
    g_tr = g_tr.to("cuda:0")

# Replay each per-epoch checkpoint and record train / validation MAE.
# NOTE: range(10000) must match the number of epochs trained above.
loss_tr = []
loss_vl = []
# Evaluation only: skip building autograd graphs for 2 x 10000 forward passes.
with torch.no_grad():
    for idx_epoch in range(10000):
        espaloma_model.load_state_dict(
            torch.load(f"{idx_epoch}.th")
        )

        # The model's return value is discarded; the metric then reads the
        # results from g_tr / g_vl themselves.
        espaloma_model(g_tr)
        loss_tr.append(inspect_metric(g_tr).item())

        espaloma_model(g_vl)
        loss_vl.append(inspect_metric(g_vl).item())
/usr/local/lib/python3.7/site-packages/dgl/base.py:45: DGLWarning: dgl.to_homo is deprecated. Please use dgl.to_homogeneous
  return warnings.warn(message, category=category, stacklevel=1)
import numpy as np
from matplotlib import pyplot as plt

# Scale the MAE values by 627.5 -- presumably a Hartree -> kcal/mol conversion
# (1 Hartree ~ 627.5 kcal/mol); confirm espaloma's internal energy unit.
loss_tr = np.array(loss_tr) * 627.5
loss_vl = np.array(loss_vl) * 627.5

plt.plot(loss_tr, label="train")
plt.plot(loss_vl, label="valid")
plt.yscale("log")
plt.xlabel("epoch")
plt.ylabel("energy MAE (kcal/mol)")
plt.legend()
<matplotlib.legend.Legend at 0x7fd8f0eebd90>
../_images/mm_fitting_small_31_1.png