# Source code for espaloma.data.normalize

# =============================================================================
# IMPORTS
# =============================================================================
import abc

import torch

import espaloma as esp


# =============================================================================
# BASE CLASSES
# =============================================================================
class BaseNormalize(abc.ABC):
    """Base class for normalizing operation.

    Subclasses implement :meth:`_prepare`, which is responsible for
    assigning the ``norm`` and ``unnorm`` callables onto the instance.
    """

    def __init__(self):
        super().__init__()

    @abc.abstractmethod
    def _prepare(self):
        # NOTE: concrete subclasses assign the `norm` / `unnorm`
        # functions to `self` inside this method.
        raise NotImplementedError
# =============================================================================
# MODULE CLASSES
# =============================================================================
class DatasetNormalNormalize(BaseNormalize):
    """Normalizing operation based on a dataset of molecules, assuming
    parameters having normal distribution.

    Parameters
    ----------
    dataset : `espaloma.data.dataset.Dataset`
        The dataset we base on to calculate the statistics of parameter
        distributions.

    Attributes
    ----------
    norm : normalize function
    unnorm : unnormalize function
    """

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset
        self._prepare()

    def _prepare(self):
        """Calculate per-parameter mean/std from the dataset and build the
        ``norm`` / ``unnorm`` closures.

        Statistics are taken along the batch axis over node-data keys ending
        in ``ref`` on the ``n1``–``n4`` node types of the batched graph.
        """
        terms = ("n1", "n2", "n3", "n4")

        # grab the collection of graphs in the dataset, batched
        g = self.dataset.batch(self.dataset.graphs)

        self.statistics = {term: {} for term in terms}

        # record mean and std of every reference parameter,
        # e.g. "k_ref" -> "k_mean" and "k_std"
        for term in terms:
            for key in g.nodes[term].data.keys():
                if not key.endswith("ref"):  # pass non-parameters
                    continue
                values = g.nodes[term].data[key]
                self.statistics[term][
                    key.replace("_ref", "_mean")
                ] = torch.mean(values, dim=0)
                self.statistics[term][
                    key.replace("_ref", "_std")
                ] = torch.std(values, dim=0)

        def norm(g):
            """Standardize every ``*ref`` parameter in place: (x - mean) / std."""
            for term in terms:
                for key in g.nodes[term].data.keys():
                    if not key.endswith("ref"):  # pass non-parameters
                        continue
                    mean = self.statistics[term][key.replace("_ref", "_mean")]
                    std = self.statistics[term][key.replace("_ref", "_std")]
                    g.nodes[term].data[key] = (
                        g.nodes[term].data[key] - mean
                    ) / std
            return g

        def unnorm(g):
            """Invert ``norm`` for any key with recorded statistics (e.g. "k")."""
            for term in terms:
                for key in g.nodes[term].data.keys():
                    if key + "_mean" not in self.statistics[term]:
                        continue
                    values = g.nodes[term].data[key]
                    # move statistics onto the data's device, mirroring
                    # DatasetLogNormalNormalize.unnorm
                    mean = self.statistics[term][key + "_mean"].to(values.device)
                    std = self.statistics[term][key + "_std"].to(values.device)
                    g.nodes[term].data[key] = values * std + mean
            return g

        # point normalize and unnormalize functions to `self`
        self.norm = norm
        self.unnorm = unnorm
class DatasetLogNormalNormalize(BaseNormalize):
    """Normalizing operation based on a dataset of molecules, assuming
    parameters having log normal distribution.

    Parameters
    ----------
    dataset : `espaloma.data.dataset.Dataset`
        The dataset we base on to calculate the statistics of parameter
        distributions.

    Attributes
    ----------
    norm : normalize function
    unnorm : unnormalize function
    """

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset
        self._prepare()

    def _prepare(self):
        """Calculate per-parameter mean/std of *log* parameter values from
        the dataset and build the ``norm`` / ``unnorm`` closures.

        Statistics are taken along the batch axis over node-data keys ending
        in ``ref`` on the ``n1``–``n4`` node types of the batched graph.
        Parameters are assumed strictly positive so that ``log`` is defined.
        """
        terms = ("n1", "n2", "n3", "n4")

        # grab the collection of graphs in the dataset, batched
        g = self.dataset.batch(self.dataset.graphs)

        self.statistics = {term: {} for term in terms}

        # record mean and std of the log of every reference parameter,
        # e.g. "k_ref" -> "k_mean" and "k_std" (in log space)
        for term in terms:
            for key in g.nodes[term].data.keys():
                if not key.endswith("ref"):  # pass non-parameters
                    continue
                log_values = g.nodes[term].data[key].log()
                self.statistics[term][
                    key.replace("_ref", "_mean")
                ] = torch.mean(log_values, dim=0)
                self.statistics[term][
                    key.replace("_ref", "_std")
                ] = torch.std(log_values, dim=0)

        def norm(g):
            """Standardize every ``*ref`` parameter in log space, in place."""
            for term in terms:
                for key in g.nodes[term].data.keys():
                    if not key.endswith("ref"):  # pass non-parameters
                        continue
                    mean = self.statistics[term][key.replace("_ref", "_mean")]
                    std = self.statistics[term][key.replace("_ref", "_std")]
                    g.nodes[term].data[key] = (
                        g.nodes[term].data[key].log() - mean
                    ) / std
            return g

        def unnorm(g):
            """Invert ``norm`` (de-standardize, then exponentiate) for any
            key with recorded statistics (e.g. "k")."""
            for term in terms:
                for key in g.nodes[term].data.keys():
                    if key + "_mean" not in self.statistics[term]:
                        continue
                    values = g.nodes[term].data[key]
                    # move statistics onto the data's device before combining
                    mean = self.statistics[term][key + "_mean"].to(values.device)
                    std = self.statistics[term][key + "_std"].to(values.device)
                    g.nodes[term].data[key] = torch.exp(values * std + mean)
            return g

        # point normalize and unnormalize functions to `self`
        self.norm = norm
        self.unnorm = unnorm
# =============================================================================
# PRESETS
# =============================================================================
class ESOL100NormalNormalize(DatasetNormalNormalize):
    """Normal-distribution normalization preset fitted on the first 100
    ESOL molecules parametrized with smirnoff99Frosst-1.1.0."""

    def __init__(self):
        forcefield = esp.graphs.legacy_force_field.LegacyForceField(
            "smirnoff99Frosst-1.1.0"
        )
        dataset = esp.data.esol(first=100).apply(
            forcefield.parametrize,
            in_place=True,
        )
        super().__init__(dataset=dataset)
class ESOL100LogNormalNormalize(DatasetLogNormalNormalize):
    """Log-normal-distribution normalization preset fitted on the first 100
    ESOL molecules parametrized with smirnoff99Frosst-1.1.0."""

    def __init__(self):
        forcefield = esp.graphs.legacy_force_field.LegacyForceField(
            "smirnoff99Frosst-1.1.0"
        )
        dataset = esp.data.esol(first=100).apply(
            forcefield.parametrize,
            in_place=True,
        )
        super().__init__(dataset=dataset)
class NotNormalize(BaseNormalize):
    """No-op normalization: ``norm`` and ``unnorm`` return their input
    unchanged."""

    def __init__(self):
        # BUG FIX: the original called `super(NotNormalize).__init__()`,
        # which omits `self` — that initializes an unbound `super` object
        # and never invokes `BaseNormalize.__init__`.
        super().__init__()
        self._prepare()

    def _prepare(self):
        # identity transforms in both directions
        self.norm = lambda x: x
        self.unnorm = lambda x: x
class PositiveNotNormalize(BaseNormalize):
    """Dataset-free normalization for positive parameters: reference
    parameters are sent to log space by ``norm``, and the ``k`` / ``eq``
    parameters of bond and angle terms are mapped back with ``exp`` by
    ``unnorm``."""

    def __init__(self):
        super().__init__()
        self._prepare()

    def _prepare(self):
        """Build the ``norm`` / ``unnorm`` closures; no statistics needed."""

        def norm(g):
            # take the log of every "*ref" parameter on n1-n4 nodes, in place
            for term in ("n1", "n2", "n3", "n4"):
                for key in g.nodes[term].data.keys():
                    if key.endswith("ref"):
                        g.nodes[term].data[key] = g.nodes[term].data[key].log()
            return g

        def unnorm(g):
            # exponentiate the "k" and "eq" parameters; only the bond (n2)
            # and angle (n3) terms are touched here
            for term in ("n2", "n3"):
                for key in g.nodes[term].data.keys():
                    if key in ("k", "eq"):
                        g.nodes[term].data[key] = torch.exp(
                            g.nodes[term].data[key]
                        )
            return g

        # point normalize and unnormalize functions to `self`
        self.norm = norm
        self.unnorm = unnorm