API Reference¶
Reading and Writing files, graphs and coordinates¶
read_write.coords_library¶
- class LMDBDataset(lmdb_path)[source]¶
Bases:
object
A class for loading PyTorch data stored in an LMDB file. The code was originally intended for graph-structured data that works with pytorch_geometric, but it should also load any other type of PyTorch data.
This class enables on-the-fly loading of serialized data stored in LMDB format, providing an efficient way to handle large datasets that cannot fit into memory.
- Parameters
lmdb_path (str): Path to the LMDB file containing the dataset.
- Attributes
lmdb_env (lmdb.Environment): The LMDB environment for data access.
length (int): The total number of entries in the dataset.
- Methods
__len__(): Returns the total number of samples in the dataset.
__getitem__(idx): Retrieves the sample at the specified index.
split_data(train_size, random_seed, shuffle): Lazily returns train and test data.
Examples
This class provides an efficient way to load huge datasets without exhausting memory.
data = coords_library.LMDBDataset(path_to_lmdb)
# Length of the dataset
print(len(data))
# Accessing a sample at index 0
sample = data[0]
print(sample.x.shape)
print(sample)
# Accessing a list of samples at different indexes
samples = data[[1, 4, 8, 9, 18, 50]]
- split_data(train_size=0.8, random_seed=None, shuffle=True)[source]¶
Lazily splits the dataset into train and test data with class-like behavior.
- Parameters:
train_size (float) – The proportion of the data to be used as the training set (default is 0.8).
random_seed (int, optional) – A random seed for reproducibility (default is None).
shuffle (bool) – Whether to shuffle the data before splitting (default is True).
- Returns:
A tuple containing train data and test data.
- Return type:
tuple
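Example
A minimal sketch (path_to_lmdb is a placeholder path to an existing LMDB file):
data = coords_library.LMDBDataset(path_to_lmdb)
train_data, test_data = data.split_data(train_size=0.8, random_seed=42, shuffle=True)
print(len(train_data), len(test_data))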
- ase_coordinate(filename)[source]¶
Read any ASE-readable file and return the coordinates and lattice vectors needed to set up AMS calculations.
- Parameters
filename (string) : Name of file containing the coordinates
- Returns
ase_coord (list) : List of coordinate strings
lattice (list) : List of lattice vector strings
- ase_database_to_lmdb(ase_database, lmdb_path)[source]¶
Converts an ASE database into an LMDB file for efficient storage and retrieval.
- Parameters
ase_database (str): Path to the ASE database.
lmdb_path (str): Path to the LMDB file where the dataset will be saved.
- ase_to_pytorch_geometric(input_system)[source]¶
Convert an ASE Atoms object to a PyTorch Geometric graph.
- Parameters
input_system (ASE.Atoms or ASE.Atom or filename): The input system to be converted.
- Returns
torch_geometric.data.Data: The converted PyTorch Geometric Data object.
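Example
A minimal sketch; the printed attributes assume the standard torch_geometric.data.Data layout:
from ase.build import molecule
from read_write import coords_library

atoms = molecule('H2O')  # any ASE Atoms object
graph = coords_library.ase_to_pytorch_geometric(atoms)
print(graph.x.shape)           # node features
print(graph.edge_index.shape)  # edge indices of shape [2, num_edges]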
- ase_to_xtb(ase_atoms)[source]¶
Create a gfn-xtb input from an ase atom object.
- Parameters
ase_atoms (ase Atoms or Atom): The ase atoms object to be converted.
- Returns
xtb_coords : The gfn-xtb input generated from the ase atoms object.
- calculate_distances(pair_indices, ase_atoms, mic=True)[source]¶
Calculate distances between pairs of atoms in an ase atoms object.
- Parameters
pair_indices : Pairs of atom indices between which distances are calculated.
ase_atoms (ase.Atoms) : The ase atoms object.
mic (bool) : Whether to apply the minimum image convention (default is True).
- charge_and_ase_from_ams_gfn(filename)[source]¶
Extract charge and ase atoms from an AMS gfn output file.
- Parameters
filename (string) : AMS gfn output file
- Returns
ase_atoms (ase Atoms object): ase atoms object
charge (float): charge of the system
- check_periodicity(filename)[source]¶
Check for periodicity in an SCM output file.
- Parameters
filename (string) : Name of file containing the coordinates
- collect_coords(filename)[source]¶
Collect coordinates from a file.
- Parameters
filename (string) : filename
- Returns
elements (list) : list of elements
positions (numpy array) : numpy array of positions
cell (numpy array) : numpy array of cell parameters, if present in the file
- compute_esp(atoms: Atoms, charges: ndarray, eps: float = 1e-12) → ndarray[source]¶
Compute the electrostatic potential (ESP) at each atomic position due to all other atoms, using the point charge approximation and the minimum image convention under periodic boundary conditions (PBC).
The electrostatic potential at atom i is computed as:
V_i = Σ_{j ≠ i} (q_j / r_ij)
- where:
V_i is the electrostatic potential at atom i,
q_j is the charge of atom j,
r_ij is the shortest distance between atoms i and j, accounting for PBC.
- Parameters
- atoms : ase.Atoms
An ASE Atoms object representing the system. The simulation cell and periodic boundary conditions (PBC) must be defined.
- charges : np.ndarray
A 1D NumPy array of atomic point charges (in units of elementary charge, e), with length equal to the number of atoms.
- eps : float, optional
A small constant added to distances to prevent division by zero (default is 1e-12 Å).
- Returns
- np.ndarray
A 1D NumPy array containing the electrostatic potential at each atomic position (in units of e/Å).
- Notes
To convert the ESP values to electronvolts (eV), multiply the result by 14.3996 (since 1 e / (4 * π * ε₀ * Å) ≈ 14.3996 eV·Å/e).
Self-interactions are excluded by setting the diagonal of the distance matrix to ∞.
Distance calculations use ASE’s get_distances function, which applies the minimum image convention under PBC.
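Example
A short sketch; the rock-salt cell and unit point charges are illustrative only:
import numpy as np
from ase.build import bulk
from read_write.coords_library import compute_esp

atoms = bulk('NaCl', 'rocksalt', a=5.64)  # periodic cell with PBC defined
charges = np.array([1.0, -1.0])           # point charges in units of e
esp = compute_esp(atoms, charges)         # ESP at each atom, in e/Å
esp_ev = esp * 14.3996                    # convert to eV (see Notes)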
- data_from_aseDb(path_to_db, num_data=25000)[source]¶
Load data from ASE database and prepare it for training.
- Parameters
path_to_db (str): Path to the ASE database file.
num_data (int): Number of entries to load from the database (default is 25000).
- Returns
list: List of PyTorch Geometric Data objects for training.
- format_coords(coords, atom_labels)[source]¶
Create coordinates containing symbols and positions.
- Parameters
coords (list) : list of coordinates
atom_labels (list) : list of atom labels
- Returns
coordinates (list) : list of formatted coordinates
- get_pairwise_connections(graph)[source]¶
Extract unique pairwise connections from an adjacency dictionary efficiently.
- Parameters
- graph (dict):
An adjacency dictionary where keys are nodes and values are arrays or lists of nodes representing neighbors.
- Returns
- list of tuple
A list of unique pairwise connections, each represented as a tuple (i, j) where i < j.
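Example
For a fully connected three-node adjacency dictionary:
graph = {0: [1, 2], 1: [0, 2], 2: [0, 1]}
pairs = coords_library.get_pairwise_connections(graph)
print(pairs)  # [(0, 1), (0, 2), (1, 2)]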
- gjf_coordinate(filename)[source]¶
Read coordinates from a Gaussian .gjf file.
- Parameters
filename (string) : Name of file containing the coordinates
- Returns
coords (list) : List of coordinate strings
lattice (list) : List of lattice vector strings
- list_train_test_split(data, train_size=0.8, random_seed=42, shuffle=True)[source]¶
Splits a list into train and test sets based on the specified train_size.
- Parameters
data (list): The input list to split.
train_size (float): The proportion of the data to be used as the training set (default is 0.8).
random_seed (int, optional): A random seed for reproducibility (default is 42).
shuffle (bool): Whether to shuffle the data before splitting (default is True).
- Returns
train_data: indices of data to be selected for training.
test_data: indices of data to be selected for testing.
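Example
A quick sketch; the 80/20 split follows the given train_size:
data = list(range(100))
train_idx, test_idx = coords_library.list_train_test_split(data, train_size=0.8, random_seed=42)
print(len(train_idx), len(test_idx))  # 80 20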
- load_data_as_ase(filename)[source]¶
Load data as an ase atoms object.
- Parameters
filename (string) : Any file type defined in this module, including ase-readable file types
- Returns
ase_atoms : ase atoms object
- prepare_dataset(ase_obj, energy)[source]¶
Prepares a dataset from ASE Atoms objects and their corresponding energy values.
- Parameters
ase_obj (ASE.Atoms): ASE Atoms object.
energy (float): Energy value of the crystal structure.
- Returns
torch_geometric.data.Data: PyTorch Geometric Data object with input features, edge indices, and energy value.
- pytorch_geometric_to_ase(data)[source]¶
Convert a PyTorch Geometric Data object back to an ASE Atoms object.
- Parameters
data (torch_geometric.data.Data): The PyTorch Geometric Data object.
- Returns
ase_atoms (ase.Atoms): The converted ASE Atoms object.
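Example
A round-trip sketch; 'structure.cif' is a placeholder filename:
atoms = coords_library.load_data_as_ase('structure.cif')
graph = coords_library.ase_to_pytorch_geometric(atoms)
atoms_back = coords_library.pytorch_geometric_to_ase(graph)
print(atoms_back.get_chemical_formula())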
- qchemcout(filename)[source]¶
Read coordinates from a QChem output file.
- Parameters
filename (string) : Name of file containing the coordinates
- Returns
coords : List of coordinate strings
- qchemin(filename)[source]¶
Read coordinates from a QChem input file.
- Parameters
filename (string) : filename
- Returns
coords : list of coordinate strings
- read_and_return_ase_atoms(filename)[source]¶
Read a file and return its ase atoms object.
- Parameters
filename: string
- scm_out(qcin)[source]¶
Extract coordinates from SCM output files.
- Parameters
qcin (string) : SCM output file
- Returns
coords (list) : list of coordinates
lattice_coords (list) : list of lattice coordinates
- write_ase_atoms(ase_atoms, filename)[source]¶
Write the ase atoms object to a file.
- Parameters
ase_atoms: ase.Atoms object
filename: string
read_write.filetyper¶
- class AtomsEncoder(*, skipkeys=False, ensure_ascii=True, check_circular=True, allow_nan=True, sort_keys=False, indent=None, separators=None, default=None)[source]¶
Bases:
JSONEncoder
Custom JSON encoder for serializing ASE Atoms objects and related data.
This encoder converts ASE Atoms objects into JSON-serializable dictionaries. It also handles the serialization of ASE Spacegroup objects.
- Methods
default(obj)
Serializes objects that are instances of ASE Atoms or Spacegroup, or falls back to the default JSON encoder for unsupported types.
- Examples
>>> from ase import Atoms
>>> import json
>>> atoms = Atoms('H2O', positions=[[0, 0, 0], [0, 0.76, 0], [0.76, 0, 0]])
>>> json_data = json.dumps(atoms, cls=AtomsEncoder)
>>> print(json_data)
- append_contents(filename, output)[source]¶
Appends a list of strings to a file. This function appends the content of a list to a file, where each element in the list represents a line to be written. If the file does not exist, it will be created.
- Parameters
- filename : str
The path to the file where the content will be appended.
- output : list of str
A list of strings to be appended to the file. Each string represents a line, and newline characters should be included if needed.
- append_json(new_data, filename)[source]¶
Appends new data to an existing JSON file. If the file does not exist or is empty, it creates a new JSON file with an empty dictionary. The function then updates the file with the provided data, overwriting existing keys if they are already present.
- Parameters
- new_data : dict
A dictionary containing the new data to append to the JSON file.
- filename : str
The path to the JSON file.
- append_json_atom(data, filename)[source]¶
Appends or updates a JSON file with data containing an ASE Atoms object. If the file does not exist or is empty, it creates a new JSON file with an empty dictionary as the initial content. The function then updates the file with the provided data using the custom AtomsEncoder for serializing ASE Atoms objects.
- Parameters
- data : dict
A dictionary containing data with an ASE Atoms object or other serializable content.
- filename : str
The path to the JSON file where the data will be appended.
- append_pickle(new_data, filename)[source]¶
Appends new data to a pickle file. This function appends new data to an existing pickle file. If the file does not exist, it will be created. Data is appended in binary format, ensuring that previously stored data is not overwritten.
- Parameters
- new_data : object
The Python object to append to the pickle file.
- filename : str
The path to the pickle file where the data will be appended.
- combine_json_files(file1_path, file2_path, output_path)[source]¶
Combines the contents of two JSON files and writes the merged result to an output JSON file.
- Parameters
- file1_path : str
The path to the first JSON file.
- file2_path : str
The path to the second JSON file.
- output_path : str
The path to the JSON file where the combined data will be written.
- csv_read(csv_file)[source]¶
Reads a CSV file and returns its content as a list of rows. This function reads the content of a CSV file and returns it as a list.
- Parameters
- csv_file : str
The path to the CSV file to be read.
- Returns
- list of list of str
A list of rows from the CSV file. Each row is a list of strings.
- get_contents(filename)[source]¶
Reads the content of a file and returns it as a list of lines. This function opens a file, reads its content line by line, and returns a list where each element is a line from the file, including newline characters.
- Parameters
- filename : str
The path to the file to be read.
- Returns
- list of str
A list containing all lines in the file.
- get_section(contents, start_key, stop_key, start_offset=0, stop_offset=0)[source]¶
Extracts a section of lines from a list of strings between specified start and stop keys. This function searches through a list of strings (e.g., file contents) to find the last occurrence of a start key and extracts all lines up to and including the first occurrence of a stop key, with optional offsets for flexibility.
- Parameters
- contents : list of str
A list of strings representing the lines of a file or text content.
- start_key : str
The key string that marks the start of the section.
- stop_key : str
The key string that marks the end of the section.
- start_offset : int, optional
The number of lines to include before the start key. Default is 0.
- stop_offset : int, optional
The number of lines to include after the stop key. Default is 0.
- Returns
- list of str
The extracted lines from contents between the start and stop keys, including the offsets.
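Example
A small sketch; 'output.log' and the key strings are placeholders:
from read_write import filetyper

contents = filetyper.get_contents('output.log')
section = filetyper.get_section(contents, 'Start of section', 'End of section')
print(''.join(section))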
- json_to_aseatom(data, filename)[source]¶
Serialize an ASE Atoms object and write it to a JSON file. This function uses the custom AtomsEncoder to convert an ASE Atoms object into a JSON format and writes the serialized data to the specified file.
- Parameters
- data : Atoms or dict
The ASE Atoms object or dictionary to serialize.
- filename : str
The path to the JSON file where the serialized data will be saved.
- json_to_numpy(json_file)[source]¶
Deserializes a JSON file containing a NumPy array back into a NumPy array. This function reads a JSON file, deserializes the data, and converts it into a NumPy array.
- Parameters
- json_file : str
The path to the JSON file containing the serialized NumPy array.
- Returns
- numpy.ndarray
The deserialized NumPy array.
- list_2_json(list_obj, file_name)[source]¶
Writes a list to a JSON file. This function serializes a Python list and saves it to a specified JSON file.
- Parameters
- list_obj : list
The list to serialize and write to the file.
- file_name : str
The path to the JSON file where the list will be saved.
- load_data(filename)[source]¶
Automatically detects the file extension and loads the data using the appropriate function. This function reads a file and returns its content, choosing the correct loading method based on the file extension. Supported file formats include JSON, CSV, Pickle, Excel, and plain text files.
- Parameters
- filename : str
The path to the file to be loaded.
- Returns
- object
The loaded data, which can be a dictionary, DataFrame, list, or other Python object, depending on the file type.
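Example
Filenames are placeholders; the return type follows the file extension:
config = filetyper.load_data('settings.json')  # dict or list
table = filetyper.load_data('results.csv')     # CSV content
model = filetyper.load_data('model.pkl')       # unpickled Python object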
- numpy_to_json(ndarray, file_name)[source]¶
Serializes a NumPy array and saves it to a JSON file. This function converts a NumPy array into a list format, which is JSON-serializable, and writes it to the specified file.
- Parameters
- ndarray : numpy.ndarray
The NumPy array to serialize.
- file_name : str
The path to the JSON file where the serialized data will be saved.
- pickle_load(filename)[source]¶
Loads and deserializes data from a pickle file. This function reads a pickle file and deserializes its content into a Python object.
- Parameters
- filename : str
The path to the pickle file to be loaded.
- Returns
- object
The deserialized Python object from the pickle file.
- put_contents(filename, output)[source]¶
Writes a list of strings into a file. This function writes the content of a list to a file, where each element in the list represents a line to be written. If the file already exists, it will be overwritten.
- Parameters
- filename : str
The path to the file where the content will be written.
- output : list of str
A list of strings to be written to the file. Each string represents a line, and newline characters should be included if needed.
- query_data(ref, data_object, col=None)[source]¶
Queries data from a CSV (as a DataFrame) or JSON (as a dictionary).
This function retrieves data based on a reference key or value from either a dictionary (JSON-like object) or a pandas DataFrame (CSV-like object).
- Parameters
- ref : str or int
The reference key or value to query.
- data_object : dict or pandas.DataFrame
The data source, which can be a dictionary (for JSON) or a pandas DataFrame (for CSV).
- col : str, optional
The column name to query in the DataFrame. This parameter is required if the data source is a DataFrame and ignored if the data source is a dictionary.
- Returns
- object
The queried data. For a dictionary, it returns the value associated with the reference key. For a DataFrame, it returns the rows where the specified column matches the reference value.
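Example
Two small sketches, one per supported data source (values are illustrative):
import pandas as pd
from read_write import filetyper

lookup = {'water': 'XLYOFNOQVPJJNP-UHFFFAOYSA-N'}
print(filetyper.query_data('water', lookup))          # value for the key

df = pd.DataFrame({'name': ['water', 'ethanol'], 'cid': [962, 702]})
print(filetyper.query_data('water', df, col='name'))  # rows where name matches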
- read_json(file_name)[source]¶
Loads and reads a JSON file. This function opens a JSON file, reads its content, and deserializes it into a Python object (e.g., a dictionary or list).
- Parameters
- file_name : str
The path to the JSON file to be read.
- Returns
- dict or list
The deserialized content of the JSON file.
- read_zip(zip_file)[source]¶
Reads and extracts the contents of a zip file.
This function opens a zip file and extracts its contents to the specified directory. If no directory is provided, it extracts to the current working directory.
- Parameters
- zip_file : str
The path to the zip file to be read and extracted.
- extract_to : str, optional
The directory where the contents of the zip file will be extracted. If not provided, the current working directory is used.
- Returns
- list of str
A list of file names contained in the zip file.
- remove_trailing_commas(json_file)[source]¶
Cleans trailing commas in a JSON file and returns the cleaned JSON string. This function reads a JSON file, removes trailing commas from objects and arrays, and returns the cleaned JSON string. It is useful for handling improperly formatted JSON files with trailing commas that are not compliant with the JSON standard.
- Parameters
- json_file : str
The path to the JSON file to be cleaned.
- Returns
- cleaned_json : str
A cleaned JSON string with trailing commas removed.
- save_dict_msgpack(data: dict, filename: str) → None[source]¶
Save a dictionary to a file using MessagePack.
- save_pickle(model, file_path)[source]¶
Saves a Python object to a file using pickle. This function serializes a Python object and saves it to a specified file in binary format using the pickle module.
- Parameters
- model : object
The Python object to serialize and save.
- file_path : str
The path to the file where the object will be saved.
- write_json(json_obj, file_name)[source]¶
Writes a Python dictionary object to a JSON file.
This function serializes a Python dictionary into JSON format, writes it to the specified file, and ensures that the JSON is human-readable with proper indentation.
- Parameters
- json_obj : dict
The Python dictionary to serialize and write to the JSON file.
- file_name : str
The path to the JSON file where the data will be saved.
read_write.cheminfo2iupac¶
- main()[source]¶
Main function to parse command-line arguments, retrieve chemical information from PubChem, and write (or append) the data to a CSV file.
- print_helpful_information()[source]¶
Prints helpful information about using the chemical_parser script.
- pubchem_to_inchikey(identifier, name='smiles')[source]¶
A function that retrieves chemical properties from PubChem using the given identifier.
This function queries the PubChem database via the PubChemPy package and extracts a set of cheminformatic properties for the first matching compound. The properties include InChIKey, CID, IUPAC name, canonical SMILES, hydrogen bond donor and acceptor counts, rotatable bond count, and charge.
- Parameters:
identifier (str) – The chemical identifier to query (e.g., chemical name or SMILES).
name (str) – The type of identifier provided. Options include ‘name’ or ‘smiles’. Default is ‘smiles’.
- Returns:
A dictionary containing the chemical properties if a matching compound is found; otherwise, None.
- Return type:
dict or None
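Example
A hedged sketch; only the property set listed above is documented, so the returned dictionary is printed whole:
info = pubchem_to_inchikey('CCO', name='smiles')  # query ethanol by SMILES
if info is not None:
    print(info)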
read_write.struct2iupac¶
- file2smile(filename)[source]¶
Reads a file and uses Open Babel to compute the SMILES, InChI, and InChIKey.
- Parameters:
filename (str): name of file containing the structure. It can be in any ase-readable file format, as well as QChem out, AMS out, or Gaussian out.
- Returns:
smi (str): SMILES string
inchi : InChI hash of the structure
inChiKey : 27-character InChIKey hash of the structure
- main()[source]¶
Main function to parse command-line arguments, retrieve chemical information from PubChem, and write (or append) the data to a CSV file.
- print_helpful_information()[source]¶
Prints helpful information about using the chemical_parser script.
- pubchem_to_inchikey(filename)[source]¶
A function that retrieves chemical properties from PubChem using the given identifier.
This function queries the PubChem database via the PubChemPy package and extracts a set of cheminformatic properties for the first matching compound. The properties include InChIKey, CID, IUPAC name, canonical SMILES, hydrogen bond donor and acceptor counts, rotatable bond count, and charge.
- Parameters:
identifier (str) – The chemical identifier to query (e.g., chemical name or SMILES).
search_type (str) – The type of identifier provided. Options include ‘name’ or ‘smiles’. Default is ‘name’.
- Returns:
A dictionary containing the chemical properties if a matching compound is found; otherwise, None.
- Return type:
dict or None
Graph Deep learning models¶
Models for predicting bond dissociation enthalpy
model.thermodynamic_stability¶
- class EarlyStopping(patience=5, delta=0.001, path='checkpoint.pth')[source]¶
Bases:
object
Implements early stopping to terminate training when validation loss stops improving.
- patience¶
Number of epochs/iterations to wait for an improvement in validation loss before stopping. Default is 5.
- Type:
int
- delta¶
Minimum improvement in validation loss required to reset the patience counter. Default is 0.001.
- Type:
float
- best_loss¶
The lowest validation loss observed so far. Initialized to infinity.
- Type:
float
- counter¶
Tracks the number of consecutive epochs without improvement in validation loss.
- Type:
int
- early_stop¶
Flag indicating whether early stopping condition is met.
- Type:
bool
- path¶
File path to save the best model checkpoint. Default is ‘checkpoint.pth’.
- Type:
str
- __call__(val_loss, model)[source]¶
Evaluates the validation loss and decides whether to stop training or save a checkpoint.
- save_checkpoint(model)[source]¶
Saves the model’s state dictionary to the specified checkpoint file.
- save_checkpoint(model, optimizer, normalise_parameter)[source]¶
Saves the current model state to the checkpoint file.
- Parameters:
model (torch.nn.Module) – The PyTorch model being trained.
optimizer (optim.Optimizer) – Optimizer for updating model parameters.
normalise_parameter (float) – Normalization parameter for target values.
- Behavior:
Saves the model’s state dictionary to the file specified by path.
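Example
A typical training-loop sketch; model, train_loader, val_loader, optimizer, criterion, device, and n_epochs are assumed to be defined, and train and evaluate are the helpers documented below:
stopper = EarlyStopping(patience=5, delta=0.001, path='checkpoint.pth')
for epoch in range(n_epochs):
    train_loss = train(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc = evaluate(model, val_loader, criterion, device)
    stopper(val_loss, model)  # saves a checkpoint when validation loss improves
    if stopper.early_stop:
        break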
- class EnergyGNN_GAT2Conv(input_dim, hidden_dim, output_dim, edge_dim, heads=4, dropout=0.2)[source]¶
Bases:
Module
A Graph Neural Network class for predicting the thermodynamic stability of MOFs using Graph Attention Networks (GATv2).
- Parameters:
input_dim (int): Number of input node features.
hidden_dim (int): Number of hidden units in the GATv2 layers.
output_dim (int): Number of output units (e.g., 1 for regression).
edge_dim (int): Number of input edge features.
heads (int, optional): Number of attention heads. Default is 4.
dropout (float, optional): Dropout rate. Default is 0.2.
- forward(data)[source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
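Example
A construction sketch; the feature dimensions are placeholders, and batch is a torch_geometric batch of graphs:
model = EnergyGNN_GAT2Conv(input_dim=9, hidden_dim=128, output_dim=1,
                           edge_dim=3, heads=4, dropout=0.2)
predictions = model(batch)  # forward pass on a batched Data object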
- class EnergyGNN_GIN(input_dim, hidden_dim, output_dim, edge_dim=None, heads=None, dropout=0.2)[source]¶
Bases:
Module
Implements a Graph Neural Network (GNN) using Graph Isomorphism Network (GIN) layers for graph-level prediction tasks such as regression or classification.
- input_dim¶
Dimensionality of input node features.
- Type:
int
- hidden_dim¶
Dimensionality of hidden layers.
- Type:
int
- output_dim¶
Dimensionality of the output.
- Type:
int
- dropout¶
Dropout rate for regularization. Default is 0.2.
- Type:
float
- Layers:
Three GINConv layers, each using an MLP for message passing.
Batch normalization after each GINConv layer.
Global mean pooling for aggregating node-level features to graph-level.
Fully connected layers for output prediction.
- forward(data)[source]¶
Defines the forward pass for the model.
- Parameters:
data (torch_geometric.data.Data) – Input graph data object containing:
- x (torch.Tensor): Node features of shape [num_nodes, input_dim].
- edge_index (torch.Tensor): Edge indices of shape [2, num_edges].
- edge_attr (torch.Tensor, optional): Edge features (not used here).
- batch (torch.Tensor): Batch indices for mini-batch training.
- Returns:
Final output predictions of shape [num_graphs, output_dim].
- Return type:
torch.Tensor
- EnergyGNN_GIN_main_model(path_to_lmdb, hidden_dim, learning_rate, batch_size, dropout, epoch, patience, print_every, save_path)[source]¶
Main function for training the EnergyGNN_GIN model.
- Parameters:
path_to_lmdb (str): Path to the LMDB file containing the dataset.
hidden_dim (int): Dimensionality of the hidden layer.
learning_rate (float): Learning rate for the optimizer.
batch_size (int): Batch size for training.
dropout (float): Dropout rate for the model.
epoch (int): Number of training epochs.
patience (int): Patience for early stopping.
print_every (int): Frequency of printing training loss and validation loss.
save_path (str): Path to save the best model.
- Behavior:
This function loads the dataset, trains the EnergyGNN_GIN model, and saves the best model.
- GAT2Conv_main_model(path_to_lmdb, hidden_dim, learning_rate, batch_size, dropout, heads, epoch, patience, print_every, save_path)[source]¶
Main function for training the GAT2Conv model.
- Parameters:
path_to_lmdb (str): Path to the LMDB file containing the dataset.
hidden_dim (int): Dimensionality of the hidden layer.
learning_rate (float): Learning rate for the optimizer.
batch_size (int): Batch size for training.
dropout (float): Dropout rate for the model.
heads (int): Number of attention heads for the GAT layer.
epoch (int): Number of training epochs.
patience (int): Patience for early stopping.
print_every (int): Frequency of printing training loss and validation loss.
save_path (str): Path to save the best model.
- Behavior:
This function loads the dataset, trains the GAT2Conv model, and saves the best model.
- class GraphLatticeModel(input_dim, gnn_hidden_dim, lattice_hidden_dim, output_dim, num_layers=3, dropout=0.2)[source]¶
Bases:
Module
A Graph Neural Network (GNN) model that incorporates both graph-based features and lattice matrix features for graph-level prediction tasks such as property regression or classification.
- input_dim¶
Dimensionality of input node features.
- Type:
int
- gnn_hidden_dim¶
Dimensionality of GNN hidden layers.
- Type:
int
- lattice_hidden_dim¶
Dimensionality of hidden layers for lattice matrix processing.
- Type:
int
- output_dim¶
Dimensionality of the output (e.g., target properties).
- Type:
int
- num_layers¶
Number of GNN layers. Default is 3.
- Type:
int
- dropout¶
Dropout rate for regularization. Default is 0.2.
- Type:
float
- forward(data)[source]¶
Performs a forward pass through the model using the input graph and lattice data.
- Parameters:
data (torch_geometric.data.Data) – Input data containing:
- x (torch.Tensor): Node features of shape [num_nodes, input_dim].
- edge_index (torch.Tensor): Edge indices of shape [2, num_edges].
- edge_attr (torch.Tensor, optional): Edge features (not used here).
- batch (torch.Tensor): Batch indices for mini-batch training.
- lattice (torch.Tensor): Lattice matrix of shape [num_graphs, 3, 3].
- Returns:
Final predictions of shape [num_graphs, output_dim].
- Return type:
torch.Tensor
- GraphLatticeModel_main_model(path_to_lmdb, hidden_dim, lattice_hidden_dim, num_layers, learning_rate, batch_size, dropout, epoch, patience, print_every, save_path)[source]¶
Main function for training the GraphLatticeModel.
- Parameters:
path_to_lmdb (str): Path to the LMDB file containing the dataset.
hidden_dim (int): Dimensionality of the GNN hidden layer.
lattice_hidden_dim (int): Dimensionality of the hidden layers for lattice matrix processing.
num_layers (int): Number of GNN layers.
learning_rate (float): Learning rate for the optimizer.
batch_size (int): Batch size for training.
dropout (float): Dropout rate for the model.
epoch (int): Number of training epochs.
patience (int): Patience for early stopping.
print_every (int): Frequency of printing training loss and validation loss.
save_path (str): Path to save the best model.
- Behavior:
This function loads the dataset, trains the GraphLatticeModel, and saves the best model.
- display_model_helper(model)[source]¶
Display clear and concise information about the selected model and its key parameters.
- Parameters:
model (str) – The name of the model. Options are:
- “EnergyGNN_GAT2Conv”: Graph Attention Networks (GATv2) based model.
- “EnergyGNN_GIN”: Graph Isomorphism Network (GIN) based model.
- “GraphLatticeModel”: Combines graph features with lattice matrix processing.
- Instructions:
Review the description for each model.
Identify the required parameters for your use case.
Ensure you provide the necessary inputs when configuring the model.
- Example Usage:
display_model_helper("EnergyGNN_GIN")
- entry_point()[source]¶
Entry point for the training script. Dynamically calls the appropriate main function based on the selected model.
- evaluate(model, dataloader, criterion, device)[source]¶
Evaluate the model using the given data and loss function, and compute accuracy.
- Parameters:
model (nn.Module): The trained model to evaluate.
dataloader (DataLoader): DataLoader for batching the dataset during evaluation.
criterion (nn.Module): Loss function (e.g., CrossEntropyLoss) to compute the evaluation loss.
device (torch.device): The device (CPU or GPU) for computation.
- Returns:
tuple: A tuple containing the average evaluation loss and accuracy: (average_loss, accuracy).
- inverse_normalize(predictions, normalization_params, method='z-score')[source]¶
Inverse the normalization of predictions to the original scale.
- Parameters:
predictions (Tensor) – The normalized predictions.
normalization_params (dict) – The parameters used for normalization.
method (str) – The normalization method (‘min-max’ or ‘z-score’).
- Returns:
Predictions in the original scale.
- Return type:
Tensor
- load_dataset(path_to_lmdb, batch_size, train_size=0.9, random_seed=42, shuffle=True, normalize='full')[source]¶
Loads a dataset from an LMDB file and splits it into training, validation, and test sets.
The function uses the coords_library.LMDBDataset to load the dataset and splits it into training and test datasets. The training dataset is further split into training and validation sets. Data loaders are created for the training and validation datasets.
- Parameters:
path_to_lmdb (str) – Path to the LMDB file containing the dataset.
batch_size (int) – Batch size for the DataLoader.
train_size (float, optional) – Fraction of the data to use for training. The rest is used for testing. Default is 0.9.
random_seed (int, optional) – Random seed for splitting the data. Ensures reproducibility. Default is 42.
shuffle (bool, optional) – Whether to shuffle the data before splitting. Default is True.
normalize (str, optional) – Normalization method to use. Can be ‘full’ for full normalization or ‘batch’ for batch normalization. Default is ‘full’.
- Returns:
train_loader (DataLoader): DataLoader for the training dataset.
val_loader (DataLoader): DataLoader for the validation dataset.
test_dataset (Dataset): Dataset object containing the test data.
- Return type:
tuple
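Example
A minimal sketch; 'data.lmdb' is a placeholder path:
train_loader, val_loader, test_dataset = load_dataset(
    'data.lmdb', batch_size=32, train_size=0.9,
    random_seed=42, shuffle=True, normalize='full')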
- load_model(path='model.pth', device='cpu')[source]¶
Load a saved model and optimizer state from a file.
- Parameters:
path (str, optional): Path to load the model. Default is “model.pth”.
device (torch.device, optional): The device (CPU or GPU) for computation. Default is “cpu”.
- Returns:
tuple: The loaded model and optimizer.
- normalize_data(dataset, method='z-score')[source]¶
Normalize the target values (data.y) in the dataset.
- Parameters:
dataset (Dataset) – The dataset object containing the data.
method (str) – The normalization method (‘min-max’ or ‘z-score’).
- Returns:
Normalized dataset and the normalization parameters.
- Return type:
tuple
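Example
A round-trip sketch with inverse_normalize; dataset and predictions are assumed to exist:
dataset, norm_params = normalize_data(dataset, method='z-score')
# ... train a model and predict on the normalized targets ...
predictions = inverse_normalize(predictions, norm_params, method='z-score')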
- parse_arguments()[source]¶
Parse command-line arguments for training the GNN model.
- Returns:
Parsed command-line arguments with detailed descriptions.
- Return type:
argparse.Namespace
- save_model(model, optimizer, normalise_parameter, path='model.pth')[source]¶
Save the trained model and optimizer state to a file.
- Parameters:
model (nn.Module): The trained GNN model to save.
optimizer (optim.Optimizer): Optimizer for updating model parameters.
normalise_parameter (float): Normalization parameter for target values.
path (str, optional): Path to save the model. Default is “model.pth”.
- train(model, dataloader, optimizer, criterion, device)[source]¶
Train the model using the given data and optimizer.
- Parameters:
model (nn.Module): The GNN model to train.
dataloader (DataLoader): DataLoader for batching the dataset during training.
optimizer (optim.Optimizer): Optimizer for updating model parameters.
criterion (nn.Module): Loss function (e.g., MSELoss) to compute the training loss.
device (torch.device): The device (CPU or GPU) for computation.
- Returns:
float: The average training loss over the epoch.
- transform_target(test_data, normalize_param, method='z-score')[source]¶
Transform the target values in the test dataset according to the normalization parameters.
- Parameters:
test_data (Dataset) – The test dataset containing the target values.
normalize_param (dict) – The normalization parameters (mean and std).
method (str, optional) – The normalization method (‘z-score’ or ‘min-max’). Default is ‘z-score’.
- Returns:
The transformed test dataset.
- Return type:
Dataset