Source code for espaloma.graphs.graph

# =============================================================================
# IMPORTS
# =============================================================================
import abc
import openff.toolkit

import espaloma as esp


# =============================================================================
# MODULE CLASSES
# =============================================================================
class BaseGraph(abc.ABC):
    """Root base class for espaloma graph objects.

    Although it derives from ``abc.ABC``, it declares no abstract methods,
    so it can be instantiated directly.
    """

    def __init__(self):
        # Cooperative initialisation along the MRO.
        super().__init__()
class Graph(BaseGraph):
    """A unified graph object that supports translation to and from
    message-passing graphs and the MM factor graph.

    Parameters
    ----------
    mol : openff.toolkit.topology.Molecule or str, optional
        Molecule the graphs are built from. A string is parsed as SMILES,
        with undefined stereochemistry allowed.
    homograph : dgl.DGLGraph, optional
        Homogeneous graph. Built from ``mol`` when neither ``homograph``
        nor ``heterograph`` is given.
    heterograph : DGL heterograph, optional
        Heterogeneous graph. Built from ``homograph`` and ``mol`` when not
        given but a homograph is available.

    Methods
    -------
    save(path)
        Save the graph to a new directory.
    load(path)
        Load a graph from a directory written by ``save``.

    Note
    ----
    This object provides access to popular attributes of the homograph and
    the heterograph: ``ndata`` and ``edata`` are taken from the homograph,
    ``nodes`` from the heterograph.

    Examples
    --------
    >>> g0 = esp.Graph("C")
    >>> g1 = esp.Graph(Molecule.from_smiles("C"))

    ``Graph`` does not define ``__eq__``, so ``g0 == g1`` compares object
    identity and is ``False`` even though both come from the same molecule.
    """

    def __init__(self, mol=None, homograph=None, heterograph=None):
        # TODO: a more pythonic API would offer alternate constructors,
        # e.g. Graph.from_smiles(...), Graph.from_mol(...),
        # Graph.from_homograph(...), rather than one constructor taking
        # several optional inputs.

        # A string input molecule is interpreted as SMILES.
        if isinstance(mol, str):
            from openff.toolkit.topology import Molecule

            mol = Molecule.from_smiles(mol, allow_undefined_stereo=True)

        # Derive only the graphs that were not supplied explicitly.
        if mol is not None and homograph is None and heterograph is None:
            homograph = self.get_homograph_from_mol(mol)

        if homograph is not None and heterograph is None:
            heterograph = self.get_heterograph_from_graph_and_mol(
                homograph, mol
            )

        self.mol = mol
        self.homograph = homograph
        self.heterograph = heterograph

    def save(self, path):
        """Save the graph into a newly created directory ``path``.

        Writes ``homograph.bin`` and ``heterograph.bin`` with
        ``dgl.save_graphs`` and the molecule as ``mol.json``.

        Parameters
        ----------
        path : str or os.PathLike
            Directory to create; its parent must already exist.

        Raises
        ------
        FileExistsError
            If ``path`` already exists (``os.mkdir`` semantics).
        """
        import json
        import os

        import dgl

        os.mkdir(path)
        dgl.save_graphs(
            os.path.join(path, "homograph.bin"), [self.homograph]
        )
        dgl.save_graphs(
            os.path.join(path, "heterograph.bin"), [self.heterograph]
        )
        # ``to_json`` returns a JSON string, so the file holds a
        # JSON-encoded string; ``load`` relies on this format.
        with open(
            os.path.join(path, "mol.json"), "w", encoding="utf-8"
        ) as f_handle:
            json.dump(self.mol.to_json(), f_handle)

    @classmethod
    def load(cls, path):
        """Load a graph from a directory written by ``save``.

        Parameters
        ----------
        path : str or os.PathLike
            Directory containing ``homograph.bin``, ``heterograph.bin``
            and ``mol.json``.

        Returns
        -------
        Graph
            A new instance built from the stored molecule and graphs.
        """
        import json
        import os

        import dgl
        from openff.toolkit.topology import Molecule

        homograph = dgl.load_graphs(os.path.join(path, "homograph.bin"))[0][0]
        heterograph = dgl.load_graphs(
            os.path.join(path, "heterograph.bin")
        )[0][0]

        with open(
            os.path.join(path, "mol.json"), "r", encoding="utf-8"
        ) as f_handle:
            payload = json.load(f_handle)

        # ``save`` stores the string from ``Molecule.to_json``; any other
        # payload (e.g. a dict) is decoded with ``Molecule.from_dict``.
        # Dispatching on type replaces a bare ``except:`` that also hid
        # genuine decoding errors (and even KeyboardInterrupt).
        if isinstance(payload, str):
            mol = Molecule.from_json(payload)
        else:
            mol = Molecule.from_dict(payload)

        return cls(mol=mol, homograph=homograph, heterograph=heterograph)

    @staticmethod
    def get_homograph_from_mol(mol):
        """Build the homogeneous graph of an OpenFF ``Molecule``.

        Delegates to
        ``esp.graphs.utils.read_homogeneous_graph.from_openff_toolkit_mol``.
        """
        assert isinstance(
            mol, openff.toolkit.topology.Molecule
        ), "mol can only be OFF Molecule object."
        # TODO: rewrite this using OFF-generic grammar.
        return esp.graphs.utils.read_homogeneous_graph.from_openff_toolkit_mol(
            mol
        )

    @staticmethod
    def get_heterograph_from_graph_and_mol(graph, mol):
        """Build the heterogeneous graph from a homograph and its molecule.

        Delegates to
        ``esp.graphs.utils.read_heterogeneous_graph.from_homogeneous_and_mol``.
        """
        import dgl

        assert isinstance(
            graph, dgl.DGLGraph
        ), "graph can only be dgl Graph object."
        return esp.graphs.utils.read_heterogeneous_graph.from_homogeneous_and_mol(
            graph, mol
        )

    @property
    def ndata(self):
        """Node data of the homograph."""
        return self.homograph.ndata

    @property
    def edata(self):
        """Edge data of the homograph."""
        return self.homograph.edata

    @property
    def nodes(self):
        """Node views of the heterograph."""
        return self.heterograph.nodes