# Source code for espaloma.data.dataset

# =============================================================================
# IMPORTS
# =============================================================================
import abc

import torch

import espaloma as esp


# =============================================================================
# MODULE CLASSES
# =============================================================================
class Dataset(abc.ABC, torch.utils.data.Dataset):
    """The base class of map-style dataset.

    Parameters
    ----------
    graphs : List
        objects in the dataset

    Methods
    -------
    shuffle
        Randomly shuffle the graphs in the dataset.

    apply(fn, in_place=True)
        Apply a function to every graph in the dataset.
        If `in_place=True`, modify the graph in-place.

    split(partitions)
        Split the dataset into partitions

    subsample(ratio, seed=None)
        Subsample the dataset.

    save(path)
        Save the dataset to a local path.

    load(path)
        Load a dataset from local path.

    Note
    ----
    This also supports iterative-style dataset by deleting `__getitem__`
    and `__len__` function.

    Attributes
    ----------
    transforms : an iterable of callables that transforms the input.
        the `__getitem__` method applies these transforms later.

    Examples
    --------
    >>> data = Dataset([esp.Graph("C")])

    """

    def __init__(self, graphs=None):
        super(Dataset, self).__init__()
        self.graphs = graphs
        # transforms are applied lazily on access (see `__getitem__`)
        self.transforms = None

    def __len__(self):
        # 0 len if no graphs
        if self.graphs is None:
            return 0
        return len(self.graphs)

    def _transform(self, graph):
        """Apply all registered transforms, in order, to a single graph.

        Identity when no transforms are registered.
        """
        if self.transforms is not None:
            for transform in self.transforms:
                graph = transform(graph)
        return graph

    def __getitem__(self, idx):
        """Index into the dataset.

        Parameters
        ----------
        idx : int, slice, or list of int

        Returns
        -------
        A single (transformed) element for an int index; a new dataset of
        the selected (transformed) elements for a slice or list index.

        Raises
        ------
        RuntimeError
            If the dataset is empty.
        TypeError
            If `idx` is not an int, slice, or list.
        """
        if self.graphs is None:
            raise RuntimeError("Empty molecule dataset.")

        if isinstance(idx, int):  # single element
            return self._transform(self.graphs[idx])

        elif isinstance(idx, slice):
            # return a Dataset object rather than a list
            return self.__class__(
                graphs=[self._transform(graph) for graph in self.graphs[idx]]
            )

        elif isinstance(idx, list):  # fancy indexing
            # NOTE: a previous version applied the transforms twice on
            # this path (once via `self[_idx]` and once explicitly);
            # transforms are now applied exactly once per element.
            return self.__class__(
                graphs=[self._transform(self.graphs[_idx]) for _idx in idx]
            )

        # was: silently returned None for unsupported index types
        raise TypeError(
            "Index must be int, slice, or list, got %s" % type(idx)
        )

    def __iter__(self):
        if self.transforms is None:
            return iter(self.graphs)
        # TODO:
        # is this efficient?
        graphs = iter(self.graphs)
        for transform in self.transforms:
            graphs = map(transform, graphs)
        return graphs

    def shuffle(self, seed=None):
        """Randomly shuffle the graphs in-place.

        Parameters
        ----------
        seed : int, optional
            If given, seed the RNG so the shuffle is reproducible.

        Returns
        -------
        self
        """
        import random
        from random import shuffle

        if seed is not None:
            random.seed(seed)
        shuffle(self.graphs)
        return self

    def apply(self, fn, in_place=False):
        r"""Apply functions to the elements of the dataset.

        Parameters
        ----------
        fn : callable

        Note
        ----
        If in_place is False, `fn` is added to the `transforms` else it
        is applied to elements and modifies them.
        """
        assert callable(fn)
        assert isinstance(in_place, bool)

        if in_place is False:  # add to list of transforms
            if self.transforms is None:
                self.transforms = []
            self.transforms.append(fn)
        else:  # modify in-place
            _graphs = []
            for graph in self.graphs:
                try:
                    _graphs.append(fn(graph))
                except Exception:
                    # best-effort: drop graphs that fail the transform.
                    # (Was a bare `except:`, which also swallowed
                    # KeyboardInterrupt/SystemExit.)
                    pass
            self.graphs = _graphs

        return self  # to allow grammar: ds = ds.apply(...)

    def split(self, partition):
        """Split the dataset according to some partition.

        Parameters
        ----------
        partition : sequence of integers or floats

        Note
        ----
        Partition sizes are computed by integer truncation, so a few
        trailing elements may be left out of every partition.
        """
        n_data = len(self)
        partition = [int(n_data * x / sum(partition)) for x in partition]
        ds = []
        idx = 0
        for p_size in partition:
            ds.append(self[idx : idx + p_size])
            idx += p_size
        return ds

    def subsample(self, ratio, seed=None):
        """Subsample the dataset according to some ratio.

        Parameters
        ----------
        ratio : float
            Ratio between the size of the subsampled dataset and the
            original dataset.

        Note
        ----
        Sampling uses `random.choices`, i.e. WITH replacement, so the
        result may contain duplicates.
        """
        n_data = len(self)
        idxs = list(range(n_data))
        import random

        random.seed(seed)
        _idxs = random.choices(idxs, k=int(n_data * ratio))
        return self[_idxs]

    def save(self, path):
        """Save dataset to path.

        Parameters
        ----------
        path : path-like object
        """
        import pickle

        with open(path, "wb") as f_handle:
            pickle.dump(self.graphs, f_handle)

    def regenerate_impropers(self, improper_def="smirnoff"):
        """Regenerate the improper nodes for all graphs.

        Parameters
        ----------
        improper_def : str
            Which convention to use for permuting impropers.
        """
        from espaloma.graphs.utils.regenerate_impropers import (
            regenerate_impropers,
        )

        for g in self.graphs:
            regenerate_impropers(g, improper_def)

    @classmethod
    def load(cls, path):
        """Load path to dataset.

        Parameters
        ----------
        path : path-like object
        """
        import pickle

        with open(path, "rb") as f_handle:
            graphs = pickle.load(f_handle)
        return cls(graphs)

    def __add__(self, x):
        # concatenation; transforms are NOT carried over to the result
        return self.__class__(self.graphs + x.graphs)
class GraphDataset(Dataset):
    """Dataset with additional support for only viewing
    certain attributes as `torch.utils.data.DataLoader`

    Methods
    -------
    view(collate_fn, *args, **kwargs)
        Provide a `torch.utils.data.DataLoader` view of the dataset.

    """

    def __init__(self, graphs=None, first=None):
        super(GraphDataset, self).__init__()
        from openff.toolkit.topology import Molecule

        # NOTE: the default used to be a mutable `[]`; `None` is
        # normalized to an empty list to avoid the shared-default pitfall.
        if graphs is None:
            graphs = []

        # if given SMILES strings or OpenFF molecules, convert them to
        # espaloma graphs; otherwise assume they already are graphs
        if all(
            isinstance(graph, Molecule) or isinstance(graph, str)
            for graph in graphs
        ):
            if first is None or first == -1:
                graphs = [esp.Graph(graph) for graph in graphs]
            else:
                # only keep the first `first` entries
                graphs = [esp.Graph(graph) for graph in graphs[:first]]

        self.graphs = graphs

    @staticmethod
    def batch(graphs):
        """Batch a homogeneous list of graphs into a single DGL graph."""
        import dgl

        if all(
            isinstance(graph, esp.graphs.graph.Graph) for graph in graphs
        ):
            return dgl.batch_hetero([graph.heterograph for graph in graphs])
        elif all(isinstance(graph, dgl.DGLGraph) for graph in graphs):
            return dgl.batch(graphs)
        elif all(isinstance(graph, dgl.DGLHeteroGraph) for graph in graphs):
            return dgl.batch_hetero(graphs)
        else:
            # was missing the space between the concatenated literals
            raise RuntimeError(
                "Can only batch DGLGraph or DGLHeterograph, "
                "now have %s" % type(graphs[0])
            )

    def view(self, collate_fn="graph", *args, **kwargs):
        """Provide a data loader.

        Parameters
        ----------
        collate_fn : callable or string
            see `collate_fn` argument for `torch.utils.data.DataLoader`

        Returns
        -------
        torch.utils.data.DataLoader
        """
        if collate_fn == "graph":
            collate_fn = self.batch

        elif collate_fn == "homograph":

            def collate_fn(graphs):
                graph = self.batch([g.homograph for g in graphs])
                return graph

        elif collate_fn == "graph-typing":

            def collate_fn(graphs):
                graph = self.batch(graphs)
                y = graph.ndata["legacy_typing"]
                return graph, y

        elif collate_fn == "graph-typing-loss":
            loss_fn = torch.nn.CrossEntropyLoss()

            def collate_fn(graphs):
                graph = self.batch(graphs)
                loss = lambda _graph: loss_fn(
                    _graph.ndata["nn_typing"], graph.ndata["legacy_typing"]
                )
                return graph, loss

        # NOTE: any other value (including an unrecognized string) is
        # passed through to DataLoader unchanged.
        return torch.utils.data.DataLoader(
            dataset=self, collate_fn=collate_fn, *args, **kwargs
        )

    def save(self, path):
        """Save each graph into an indexed subdirectory of `path`.

        Parameters
        ----------
        path : path-like object

        Raises
        ------
        FileExistsError
            If `path` already exists (`os.mkdir` semantics).
        """
        import os

        os.mkdir(path)
        for idx, graph in enumerate(self.graphs):
            graph.save(path + "/" + str(idx))

    @classmethod
    def load(cls, path):
        """Load a dataset saved with `save`.

        Parameters
        ----------
        path : path-like object
        """
        import os

        # `save` names subdirectories by integer index, but `os.listdir`
        # returns entries in arbitrary order; sort numerically so load
        # order matches save order. Non-numeric entries (if any) sort
        # after the numeric ones, lexicographically.
        paths = sorted(
            os.listdir(path),
            key=lambda name: (
                not name.isdigit(),
                int(name) if name.isdigit() else name,
            ),
        )

        graphs = []
        for _path in paths:
            graphs.append(esp.Graph.load(path + "/" + _path))

        return cls(graphs)