Toy experiment: Molecular mechanics (MM) fitting on subsampled PhAlkEthOH dataset.

This notebook is intended to recover the MM fitting behavior in

To assess how well Espaloma can learn to reproduce an MM force field from a limited amount of data, we selected a chemical dataset of limited complexity, PhAlkEthOH, which consists of linear and cyclic molecules containing phenyl rings, small alkanes, ethers, and alcohols composed of only the elements carbon, oxygen, and hydrogen. We generated a set of conformational snapshots for each molecule using short high-temperature molecular dynamics simulations at 300 K, initiated from multiple conformations to ensure adequate sampling of conformers. The PhAlkEthOH dataset was randomly partitioned (by molecule) into 80% training, 10% validation, and 10% test molecules, with 100 snapshots per molecule, and an Espaloma model was trained with early stopping, monitoring for a decrease in accuracy on the validation set.
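
The molecule-level 80%/10%/10% partition described above can be sketched in plain Python. This is only an illustration of splitting by molecule rather than by snapshot; the function name `split_by_molecule`, the seed, and the integer molecule identifiers are hypothetical and are not part of Espaloma.

```python
import random

def split_by_molecule(molecule_ids, ratios=(8, 1, 1), seed=0):
    """Shuffle molecule identifiers and cut them into consecutive partitions."""
    ids = list(molecule_ids)
    random.Random(seed).shuffle(ids)  # hypothetical fixed seed, for reproducibility
    total = sum(ratios)
    partitions, start = [], 0
    for i, ratio in enumerate(ratios):
        # the last partition takes whatever remains, so no molecule is dropped
        if i == len(ratios) - 1:
            end = len(ids)
        else:
            end = start + (len(ids) * ratio) // total
        partitions.append(ids[start:end])
        start = end
    return partitions

# 100 hypothetical molecules, identified by integers
train, valid, test = split_by_molecule(range(100))
print(len(train), len(valid), len(test))  # 80 10 10
```

Because the split is made over molecules, all snapshots of a given molecule end up in the same partition, so the validation and test sets contain only unseen molecules.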

Installation and imports

# install conda
! pip install -q condacolab
import condacolab
condacolab.install()  # set up conda/mamba in the Colab runtime (restarts the kernel)
! mamba install --yes --strict-channel-priority --channel jaimergp/label/unsupported-cudatoolkit-shim --channel omnia --channel omnia/label/cuda100 --channel dglteam --channel numpy openmm openmmtools openmmforcefields rdkit openff-toolkit dgl-cuda10.0 qcportal
! git clone
import torch
import sys
import espaloma as esp

Load dataset

Here we load the PhAlkEthOH dataset and shuffle it before splitting into training, validation, and test sets (80%:10%:10%).

! wget
! unzip
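ds = esp.data.dataset.GraphDataset.load("phalkethoh")  # load the unzipped dataset
ds.shuffle()  # shuffle molecules before splitting
ds_tr, ds_vl, ds_te = ds.split([8, 1, 1])  # 80% / 10% / 10%

A training dataloader is constructed with batch_size=100.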

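ds_tr_loader = ds_tr.view(batch_size=100, shuffle=True)  # mini-batches for training
# the whole training and validation sets as single batched graphs, used for evaluation
g_tr = next(iter(ds_tr.view(batch_size=len(ds_tr))))
g_vl = next(iter(ds_vl.view(batch_size=len(ds_vl))))

Define model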

Define Espaloma stage I: graph -> atom latent representation

representation = esp.nn.Sequential(
    layer=esp.nn.layers.dgl_legacy.gn("SAGEConv"), # use SAGEConv implementation in DGL
    config=[128, "relu", 128, "relu", 128, "relu"], # 3 layers, 128 units, ReLU activation
)

Define Espaloma stages II and III: atom latent representation -> bond, angle, and torsion representations and parameters. Then compose all three Espaloma stages into an end-to-end model.

readout = esp.nn.readout.janossy.JanossyPooling(
    in_features=128, config=[128, "relu", 128, "relu", 128, "relu"],
    out_features={              # define modular MM parameters Espaloma will assign
        1: {"e": 1, "s": 1}, # atom electronegativity (e) and hardness (s)
        2: {"log_coefficients": 2}, # bond linear combination, enforce positive
        3: {"log_coefficients": 2}, # angle linear combination, enforce positive
        4: {"k": 6}, # torsion barrier heights (can be positive or negative)
    },
)
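
For intuition on stage II: Janossy pooling makes the representation of a bond, angle, or torsion independent of the order in which its atoms are listed, by summing a learned function over the equivalent orderings (forward and reversed). The sketch below shows the idea in plain Python with a toy stand-in for the neural network. The names `toy_network` and `janossy_pool` and the example vectors are hypothetical illustrations, not Espaloma's API.

```python
def toy_network(features):
    """Stand-in for a learned network: deliberately sensitive to input order."""
    return sum((i + 1) * x for i, x in enumerate(features))

def janossy_pool(atom_embeddings):
    """Sum the network output over the forward and reversed atom orderings."""
    forward = [x for h in atom_embeddings for x in h]             # h_i, h_j, ...
    backward = [x for h in reversed(atom_embeddings) for x in h]  # ..., h_j, h_i
    return toy_network(forward) + toy_network(backward)

# hypothetical two-dimensional atom embeddings for the two atoms of a bond
h_i, h_j = [0.5, -1.0], [2.0, 0.25]

# the raw network gives different answers for (i, j) and (j, i)
print(toy_network(h_i + h_j), toy_network(h_j + h_i))
# the pooled value is the same for both orderings
print(janossy_pool([h_i, h_j]), janossy_pool([h_j, h_i]))
```

The same function works for angles and torsions, where reversing the atom order again describes the same chemical term.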

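espaloma_model = torch.nn.Sequential(
                 representation, readout, esp.nn.readout.janossy.ExpCoefficients(),
                 esp.mm.geometry.GeometryInGraph(),           # bond lengths, angles, torsions from coordinates
                 esp.mm.energy.EnergyInGraph(),               # predicted MM energy "u"
                 esp.mm.energy.EnergyInGraph(suffix="_ref"),  # reference MM energy "u_ref"
)
if torch.cuda.is_available():
    espaloma_model = espaloma_model.cuda()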
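
A note on the `log_coefficients` outputs: predicting the logarithm of two coefficients and exponentiating them guarantees positive values, and a positive combination of two fixed harmonic basis terms is equivalent, up to a constant, to a single harmonic term. The sketch below works through that algebra in plain Python. The function name and the basis positions `r1` and `r2` are hypothetical values chosen for illustration; this is not a description of Espaloma's internal implementation.

```python
import math

def harmonic_from_log_coefficients(log_c1, log_c2, r1, r2):
    """Combine two harmonic basis terms with positive weights.

    k1/2 (r - r1)^2 + k2/2 (r - r2)^2 equals k/2 (r - r0)^2 plus a constant,
    with k = k1 + k2 and r0 = (k1 r1 + k2 r2) / (k1 + k2).
    """
    k1, k2 = math.exp(log_c1), math.exp(log_c2)  # exponentiation enforces k1, k2 > 0
    k = k1 + k2                                  # effective force constant, always positive
    r0 = (k1 * r1 + k2 * r2) / k                 # equilibrium value, always between r1 and r2
    return k, r0

# hypothetical basis positions; equal weights put the minimum at the midpoint
k, r0 = harmonic_from_log_coefficients(0.0, 0.0, r1=1.0, r2=2.0)
print(k, r0)  # 2.0 1.5
```

One consequence of this parametrization is that, however the network outputs vary, the force constant stays positive and the equilibrium value stays inside the interval spanned by the basis positions.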

The loss function is the mean squared error (MSE) between the predicted and reference energies.

loss_fn = esp.metrics.GraphMetric(
        base_metric=torch.nn.MSELoss(), # use mean-squared error loss
        between=['u', "u_ref"],         # between predicted and reference energies
        level="g", # compare on graph level
)
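
As a reminder of what this loss computes, the sketch below evaluates the mean squared error between paired predicted and reference energies in plain Python. The function name and the energy values are made up for illustration and are not results from the dataset.

```python
def mean_squared_error(predicted, reference):
    """Average of the squared differences between paired values."""
    if len(predicted) != len(reference):
        raise ValueError("inputs must have the same length")
    return sum((p - r) ** 2 for p, r in zip(predicted, reference)) / len(predicted)

u = [1.0, 2.0, 4.0]      # hypothetical predicted energies
u_ref = [1.0, 3.0, 2.0]  # hypothetical reference energies
print(mean_squared_error(u, u_ref))  # (0 + 1 + 4) / 3
```

Squaring the differences penalizes large energy errors more strongly than small ones, which is why a separate absolute-error metric is often easier to read when inspecting results.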

Define optimizer

optimizer = torch.optim.Adam(espaloma_model.parameters(), 1e-4)

Train it!

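for idx_epoch in range(10000):
    for g in ds_tr_loader:
        optimizer.zero_grad()  # clear gradients from the previous step
        if torch.cuda.is_available():
            g = g.to("cuda:0")
        g = espaloma_model(g)
        loss = loss_fn(g)
        loss.backward()  # backpropagate the loss
        optimizer.step()
    # save a checkpoint after every epoch (the file-name pattern is elided in the source)
    torch.save(espaloma_model.state_dict(), "" % idx_epoch)

inspect_metric = esp.metrics.GraphMetric(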
        base_metric=torch.nn.L1Loss(), # use mean absolute error (L1) loss
        between=['u', "u_ref"],         # between predicted and reference energies
        level="g", # compare on graph level
)

if torch.cuda.is_available():
    g_vl = g_vl.to("cuda:0")
    g_tr = g_tr.to("cuda:0")

loss_tr = []
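loss_vl = []
for idx_epoch in range(10000):
    espaloma_model.load_state_dict(torch.load("" % idx_epoch))  # checkpoint file-name pattern elided in the source
    for g, losses in ((g_tr, loss_tr), (g_vl, loss_vl)):
        g = espaloma_model(g)  # recompute predicted energies with this epoch's weights
        losses.append(inspect_metric(g).item())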


import numpy as np
# convert from hartree to kcal/mol (1 hartree is about 627.5 kcal/mol)
loss_tr = np.array(loss_tr) * 627.5
loss_vl = np.array(loss_vl) * 627.5
from matplotlib import pyplot as plt
plt.plot(loss_tr, label="train")
plt.plot(loss_vl, label="valid")
plt.xlabel("epoch")
plt.ylabel("MAE (kcal/mol)")
plt.legend()  # show the train/valid labels
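
The early stopping mentioned in the introduction amounts to picking the checkpoint at which the validation error was lowest. A minimal sketch in plain Python, assuming one validation value per epoch; the function name `best_epoch` and the example curve are hypothetical and are not taken from the actual run.

```python
def best_epoch(validation_losses):
    """Return the index of the epoch with the lowest validation loss."""
    if len(validation_losses) == 0:
        raise ValueError("need at least one epoch")
    # on ties, min() keeps the earliest epoch
    return min(range(len(validation_losses)), key=lambda i: validation_losses[i])

# made-up validation curve: it improves, then rises again as the model overfits
example_loss_vl = [5.0, 3.2, 2.1, 1.9, 2.4, 3.0]
print(best_epoch(example_loss_vl))  # 3
```

Applied to this notebook, the same selection could be run on `loss_vl` to decide which saved checkpoint to evaluate on the test set.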
