API Reference

Reading and writing files, graphs, and coordinates

read_write.coords_library

class LMDBDataset(lmdb_path)[source]

Bases: object

A class for loading PyTorch data stored in an LMDB file. The code was originally intended for graph-structured data compatible with pytorch_geometric, but it should load any type of PyTorch data.

This class enables on-the-fly loading of serialized data stored in LMDB format, providing an efficient way to handle large datasets that cannot fit into memory.

Parameters

lmdb_path (str): Path to the LMDB file containing the dataset.

Attributes

lmdb_env (lmdb.Environment): The LMDB environment for data access.
length (int): The total number of entries in the dataset.

Methods

__len__(): Returns the total number of samples in the dataset.
__getitem__(idx): Retrieves the sample at the specified index.
split_data(train_size, random_seed, shuffle): Lazily returns train and test data.

Examples

This class provides an efficient way to work with very large datasets without loading them entirely into memory.

data = coords_library.LMDBDataset(path_to_lmdb)

# Length of the dataset

print(len(data))

# Accessing a sample at index 0

sample = data[0]

print(sample.x.shape)

print(sample)

# Accessing a list of samples at different indexes

samples = data[[1, 4, 8, 9, 18, 50]]

split_data(train_size=0.8, random_seed=None, shuffle=True)[source]

Lazily splits the dataset into train and test subsets with dataset-like behavior.

Parameters:
  • train_size (float) – The proportion of the data to be used as the training set (default is 0.8).

  • random_seed (int, optional) – A random seed for reproducibility (default is None).

  • shuffle (bool) – Whether to shuffle the data before splitting (default is True).

Returns:

A tuple containing train data and test data.

Return type:

tuple
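A minimal usage sketch (the LMDB path is a placeholder):

# Split the dataset lazily into train and test subsets
data = coords_library.LMDBDataset("dataset.lmdb")
train_data, test_data = data.split_data(train_size=0.8, random_seed=42, shuffle=True)
print(len(train_data), len(test_data))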

ase_coordinate(filename)[source]

Reads any ASE-readable file and returns the coordinates and lattice, which can be used to set up AMS calculations.

Parameters

filename (string) : Name of file containing the coordinate

Returns

ase_coord (list) : List of coordinate strings
lattice (list) : List of lattice vector strings

ase_database_to_lmdb(ase_database, lmdb_path)[source]

Converts an ASE database into an LMDB file for efficient storage and retrieval.

Parameters

ase_database (str): Path to the ASE database.
lmdb_path (str): Path to the LMDB file where the dataset will be saved.

ase_to_pytorch_geometric(input_system)[source]

Convert an ASE Atoms object to a PyTorch Geometric graph

Parameters

input_system (ASE.Atoms or ASE.Atom or filename): The input system to be converted.

Returns

torch_geometric.data.Data: The converted PyTorch Geometric Data object.
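A short usage sketch, assuming the module is imported under the path shown on this page:

from ase.build import molecule
from read_write import coords_library

# Convert a small molecule to a PyTorch Geometric graph
atoms = molecule("H2O")
data = coords_library.ase_to_pytorch_geometric(atoms)
print(data.x.shape)           # node feature matrix
print(data.edge_index.shape)  # graph connectivity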

ase_to_xtb(ase_atoms)[source]

Creates a GFN-xTB input from an ASE Atoms object.

Parameters

ase_atoms (ase Atoms or Atom): The ASE Atoms object to be converted.

Returns

xtb_coords : the GFN-xTB coordinate input generated from the ASE Atoms object

calculate_distances(pair_indices, ase_atoms, mic=True)[source]

Calculates distances between pairs of atoms in an ASE Atoms object. With mic=True, the minimum image convention is applied under periodic boundary conditions.

charge_and_ase_from_ams_gfn(filename)[source]

Extracts the charge and an ASE Atoms object from an AMS GFN output file.

Parameters

filename (string) : AMS GFN output file

Returns

ase_atoms (ase.Atoms): ASE Atoms object
charge (float): charge of the system

check_periodicity(filename)[source]

Checks for periodicity in an SCM output file.

Parameters

filename (string) : Name of the file containing the coordinates

collect_coords(filename)[source]

Collects coordinates from a file.

Parameters

filename (string) : filename

Returns

elements (list) : list of elements
positions (numpy array) : numpy array of positions
cell (numpy array) : numpy array of cell parameters, if present in the file

compute_esp(atoms: Atoms, charges: ndarray, eps: float = 1e-12) → ndarray[source]

Compute the electrostatic potential (ESP) at each atomic position due to all other atoms, using the point charge approximation and the minimum image convention under periodic boundary conditions (PBC).

The electrostatic potential at atom i is computed as:

V_i = Σ_{j ≠ i} (q_j / r_ij)

where:
  • V_i is the electrostatic potential at atom i,

  • q_j is the charge of atom j,

  • r_ij is the shortest distance between atoms i and j, accounting for PBC.

Parameters
  • atoms : ase.Atoms

    An ASE Atoms object representing the system. The simulation cell and periodic boundary conditions (PBC) must be defined.

  • charges : np.ndarray

    A 1D NumPy array of atomic point charges (in units of elementary charge, e), with length equal to the number of atoms.

  • eps : float, optional

    A small constant added to distances to prevent division by zero (default is 1e-12 Å).

Returns

np.ndarray

A 1D NumPy array containing the electrostatic potential at each atomic position (in units of e/Å).

Notes
  • To convert the ESP values to electronvolts (eV), multiply the result by 14.3996 (since 1 e / (4 * π * ε₀ * Å) ≈ 14.3996 eV·Å/e).

  • Self-interactions are excluded by setting the diagonal of the distance matrix to ∞.

  • Distance calculations use ASE’s get_distances function, which applies the minimum image convention under PBC.
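A usage sketch under the point charge approximation; the charges below are illustrative formal charges, not fitted values:

import numpy as np
from ase.build import bulk

# Rocksalt NaCl primitive cell, periodic in all directions
atoms = bulk("NaCl", "rocksalt", a=5.64)
charges = np.array([1.0, -1.0])    # illustrative point charges in e

esp = compute_esp(atoms, charges)  # ESP in e/Å at each atomic site
print(esp * 14.3996)               # converted to eV, per the Notes above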

coordinate_definition(filename)[source]

Defines how coordinates should be extracted for a given file type.

data_from_aseDb(path_to_db, num_data=25000)[source]

Load data from ASE database and prepare it for training.

Parameters

path_to_db (str): Path to the ASE database file.
num_data (int): Number of entries to load from the database (default is 25000).

Returns

list: List of PyTorch Geometric Data objects for training.

format_coords(coords, atom_labels)[source]

Creates coordinate entries combining atomic symbols and positions.

Parameters

coords (list) : list of coordinates
atom_labels (list) : list of atom labels

Returns

coordinates (list) : list of formatted coordinates

get_pairwise_connections(graph)[source]

Extract unique pairwise connections from an adjacency dictionary efficiently.

Parameters
graph (dict):

An adjacency dictionary where keys are nodes and values are arrays or lists of nodes representing neighbors.

Returns

list of tuple

A list of unique pairwise connections, each represented as a tuple (i, j) where i < j.
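A small example, assuming integer node labels:

# Adjacency dictionary for a triangle graph
graph = {0: [1, 2], 1: [0, 2], 2: [0, 1]}
pairs = get_pairwise_connections(graph)
print(pairs)  # [(0, 1), (0, 2), (1, 2)]; each edge appears once, with i < j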

gjf_coordinate(filename)[source]

Reads coordinates from a Gaussian .gjf file.

Parameters

filename (string) : Name of the file containing the coordinates

Returns

coords (list) : List of coordinate strings
lattice (list) : List of lattice vector strings

list_train_test_split(data, train_size=0.8, random_seed=42, shuffle=True)[source]

Splits a list into train and test sets based on the specified train_size.

Parameters

data (list): The input list to split.
train_size (float): The proportion of the data to be used as the training set (default is 0.8).
random_seed (int, optional): A random seed for reproducibility (default is 42).
shuffle (bool): Whether to shuffle the data before splitting (default is True).

Returns

train_data: indices of data to be selected for training.
test_data: indices of data to be selected for testing.
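A quick sketch; note that the function returns indices rather than the list items themselves:

data = list(range(10))
train_idx, test_idx = list_train_test_split(data, train_size=0.8, random_seed=42)
train_items = [data[i] for i in train_idx]
test_items = [data[i] for i in test_idx]
print(len(train_items), len(test_items))  # 8 2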

load_data_as_ase(filename)[source]

Loads data as an ASE Atoms object.

Parameters

filename (string) : Any file type that has been defined in this module, including ASE-readable file types

Returns

ase_atoms : ASE Atoms object

prepare_dataset(ase_obj, energy)[source]

Prepares a dataset from ASE Atoms objects and their corresponding energy values.

Parameters

ase_obj (ASE.Atoms): ASE Atoms object.
energy (float): Energy value of the crystal structure.

Returns

torch_geometric.data.Data: PyTorch Geometric Data object with input features, edge indices, and energy value.

pytorch_geometric_to_ase(data)[source]

Convert a PyTorch Geometric Data object back to an ASE Atoms object.

Parameters

data (torch_geometric.data.Data): The PyTorch Geometric Data object.

Returns

ase_atoms (ase.Atoms): The converted ASE Atoms object.

qchemcout(filename)[source]

Reads coordinates from a Q-Chem output file.

Parameters

filename (string) : Name of file containing the coordinate

Returns

coords : List of coordinate strings

qchemin(filename)[source]

Reads coordinates from a Q-Chem input file.

Parameters

filename (string) : filename

Returns

coords : list of coordinate strings

read_and_return_ase_atoms(filename)[source]

Reads a file and returns an ASE Atoms object.

Parameters

filename: string

scm_out(qcin)[source]

Extracts coordinates from SCM output files.

Parameters

qcin (string) : SCM output file

Returns

coords (list) : list of coordinates
lattice_coords (list) : list of lattice coordinates

write_ase_atoms(ase_atoms, filename)[source]

Writes an ASE Atoms object to a file.

Parameters

ase_atoms: ase.Atoms object
filename: string

xtb_input(filename)[source]

Creates a GFN-xTB input file from any ASE-readable file type, or any other file type supported by this module.

Parameters

filename (string) : Any file type that has been defined in this module

Returns

xtb_coords : list of strings containing xtb input

xyz_coordinates(filename)[source]

Reads any XYZ coordinate file.

Parameters

filename (string) : Name of file containing the coordinate

Returns

coords : List of coordinate strings

read_write.filetyper

class AtomsEncoder(*, skipkeys=False, ensure_ascii=True, check_circular=True, allow_nan=True, sort_keys=False, indent=None, separators=None, default=None)[source]

Bases: JSONEncoder

Custom JSON encoder for serializing ASE Atoms objects and related data.

This encoder converts ASE Atoms objects into JSON-serializable dictionaries. It also handles the serialization of ASE Spacegroup objects.

Methods

default(obj)

Serializes objects that are instances of ASE Atoms or Spacegroup, or falls back to the default JSON encoder for unsupported types.

Examples
>>> from ase import Atoms
>>> import json
>>> atoms = Atoms('H2O', positions=[[0, 0, 0], [0, 0.76, 0], [0.76, 0, 0]])
>>> json_data = json.dumps(atoms, cls=AtomsEncoder)
>>> print(json_data)
default(encorder_obj)[source]

Defines the encoder used to serialize ASE Atoms objects; unsupported types fall back to the default JSONEncoder.

append_contents(filename, output)[source]

Appends a list of strings to a file. This function appends the content of a list to a file, where each element in the list represents a line to be written. If the file does not exist, it will be created.

Parameters

filename : str

The path to the file where the content will be appended.

output : list of str

A list of strings to be appended to the file. Each string represents a line, and newline characters should be included if needed.

append_json(new_data, filename)[source]

Appends new data to an existing JSON file. If the file does not exist or is empty, it creates a new JSON file with an empty dictionary. The function then updates the file with the provided data, overwriting existing keys if they are already present.

Parameters

new_data : dict

A dictionary containing the new data to append to the JSON file.

filename : str

The path to the JSON file.
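A minimal sketch (the file name and keys are placeholders):

# The first call creates results.json; later calls update or add keys
append_json({"mof_1": {"energy": -1.23}}, "results.json")
append_json({"mof_2": {"energy": -4.56}}, "results.json")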

append_json_atom(data, filename)[source]

Appends or updates a JSON file with data containing an ASE Atoms object. If the file does not exist or is empty, it creates a new JSON file with an empty dictionary as the initial content. The function then updates the file with the provided data using the custom AtomsEncoder for serializing ASE Atoms objects.

Parameters

data : dict

A dictionary containing data with an ASE Atoms object or other serializable content.

filename : str

The path to the JSON file where the data will be appended.

append_pickle(new_data, filename)[source]

Appends new data to a pickle file. This function appends new data to an existing pickle file. If the file does not exist, it will be created. Data is appended in binary format, ensuring that previously stored data is not overwritten.

Parameters

new_data : object

The Python object to append to the pickle file.

filename : str

The path to the pickle file where the data will be appended.

combine_json_files(file1_path, file2_path, output_path)[source]

Combines the contents of two JSON files and writes the merged result to a single output JSON file.

Parameters

file1_path : str

The path to the first JSON file.

file2_path : str

The path to the second JSON file.

output_path : str

The path to the JSON file where the combined data will be written.

csv_read(csv_file)[source]

Reads a CSV file and returns its content as a list of rows. This function reads the content of a CSV file and returns it as a list.

Parameters

csv_file : str

The path to the CSV file to be read.

Returns

list of list of str

A list of rows from the CSV file. Each row is a list of strings.

get_contents(filename)[source]

Reads the content of a file and returns it as a list of lines. This function opens a file, reads its content line by line, and returns a list where each element is a line from the file, including newline characters.

Parameters

filename : str

The path to the file to be read.

Returns

list of str

A list containing all lines in the file.

get_section(contents, start_key, stop_key, start_offset=0, stop_offset=0)[source]

Extracts a section of lines from a list of strings between specified start and stop keys. This function searches through a list of strings (e.g., file contents) to find the last occurrence of a start key and extracts all lines up to and including the first occurrence of a stop key, with optional offsets for flexibility.

Parameters

contents : list of str

A list of strings representing the lines of a file or text content.

start_key : str

The key string that marks the start of the section.

stop_key : str

The key string that marks the end of the section.

start_offset : int, optional

The number of lines to include before the start key. Default is 0.

stop_offset : int, optional

The number of lines to include after the stop key. Default is 0.

Returns

list of str

The extracted lines from contents between the start and stop keys, including the offsets.
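A usage sketch; the start and stop keys below are placeholders for whatever markers your file actually contains:

contents = get_contents("output.log")  # read the file as a list of lines
block = get_section(contents, "Begin coordinates", "End coordinates")
print("".join(block))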

json_to_aseatom(data, filename)[source]

Serialize an ASE Atoms object and write it to a JSON file. This function uses the custom AtomsEncoder to convert an ASE Atoms object into a JSON format and writes the serialized data to the specified file.

Parameters

data : Atoms or dict

The ASE Atoms object or dictionary to serialize.

filename : str

The path to the JSON file where the serialized data will be saved.

json_to_numpy(json_file)[source]

Deserializes a JSON file containing a NumPy array back into a NumPy array. This function reads a JSON file, deserializes the data, and converts it into a NumPy array.

Parameters

json_file : str

The path to the JSON file containing the serialized NumPy array.

Returns

numpy.ndarray

The deserialized NumPy array.

list_2_json(list_obj, file_name)[source]

Writes a list to a JSON file. This function serializes a Python list and saves it to a specified JSON file.

Parameters

list_obj : list

The list to serialize and write to the file.

file_name : str

The path to the JSON file where the list will be saved.

load_data(filename)[source]

Automatically detects the file extension and loads the data using the appropriate function. This function reads a file and returns its content, choosing the correct loading method based on the file extension. Supported file formats include JSON, CSV, Pickle, Excel, and plain text files.

Parameters

filename : str

The path to the file to be loaded.

Returns

object

The loaded data, which can be a dictionary, DataFrame, list, or other Python object, depending on the file type.

load_dict_msgpack(filename: str) → dict[source]

Load a dictionary from a MessagePack file.

numpy_to_json(ndarray, file_name)[source]

Serializes a NumPy array and saves it to a JSON file. This function converts a NumPy array into a list format, which is JSON-serializable, and writes it to the specified file.

Parameters

ndarray : numpy.ndarray

The NumPy array to serialize.

file_name : str

The path to the JSON file where the serialized data will be saved.

pickle_load(filename)[source]

Loads and deserializes data from a pickle file. This function reads a pickle file and deserializes its content into a Python object.

Parameters

filename : str

The path to the pickle file to be loaded.

Returns

object

The deserialized Python object from the pickle file.

put_contents(filename, output)[source]

Writes a list of strings into a file. This function writes the content of a list to a file, where each element in the list represents a line to be written. If the file already exists, it will be overwritten.

Parameters

filename : str

The path to the file where the content will be written.

output : list of str

A list of strings to be written to the file. Each string represents a line, and newline characters should be included if needed.

query_data(ref, data_object, col=None)[source]

Queries data from a CSV (as a DataFrame) or JSON (as a dictionary).

This function retrieves data based on a reference key or value from either a dictionary (JSON-like object) or a pandas DataFrame (CSV-like object).

Parameters

ref : str or int

The reference key or value to query.

data_object : dict or pandas.DataFrame

The data source, which can be a dictionary (for JSON) or a pandas DataFrame (for CSV).

col : str, optional

The column name to query in the DataFrame. This parameter is required if the data source is a DataFrame and ignored if the data source is a dictionary.

Returns

object

The queried data. For a dictionary, it returns the value associated with the reference key. For a DataFrame, it returns the rows where the specified column matches the reference value.
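A short sketch covering both supported data sources:

import pandas as pd

# Dictionary (JSON-like) source: returns the value stored under the key
info = {"water": {"cid": 962}}
print(query_data("water", info))

# DataFrame (CSV-like) source: returns matching rows; col is required
df = pd.DataFrame({"name": ["water", "ethanol"], "cid": [962, 702]})
print(query_data("water", df, col="name"))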

read_json(file_name)[source]

Loads and reads a JSON file. This function opens a JSON file, reads its content, and deserializes it into a Python object (e.g., a dictionary or list).

Parameters

file_name : str

The path to the JSON file to be read.

Returns

dict or list

The deserialized content of the JSON file.

read_zip(zip_file)[source]

Reads and extracts the contents of a zip file.

This function opens a zip file and extracts its contents to the specified directory. If no directory is provided, it extracts to the current working directory.

Parameters

zip_file : str

The path to the zip file to be read and extracted.

extract_to : str, optional

The directory where the contents of the zip file will be extracted. If not provided, the current working directory is used.

Returns

list of str

A list of file names contained in the zip file.

remove_trailing_commas(json_file)[source]

Cleans trailing commas in a JSON file and returns the cleaned JSON string. This function reads a JSON file, removes trailing commas from objects and arrays, and returns the cleaned JSON string. It is useful for handling improperly formatted JSON files with trailing commas that are not compliant with the JSON standard.

Parameters

json_file : str

The path to the JSON file to be cleaned.

Returns

cleaned_json : str

A cleaned JSON string with trailing commas removed.

save_dict_msgpack(data: dict, filename: str) → None[source]

Save a dictionary to a file using MessagePack.

save_pickle(model, file_path)[source]

Saves a Python object to a file using pickle. This function serializes a Python object and saves it to a specified file in binary format using the pickle module.

Parameters

model : object

The Python object to serialize and save.

file_path : str

The path to the file where the object will be saved.

write_json(json_obj, file_name)[source]

Writes a Python dictionary object to a JSON file.

This function serializes a Python dictionary into JSON format and writes it to the specified file, ensuring the JSON is human-readable with proper indentation.

Parameters

json_obj : dict

The Python dictionary to serialize and write to the JSON file.

file_name : str

The path to the JSON file where the data will be saved.

read_write.cheminfo2iupac

main()[source]

Main function to parse command-line arguments, retrieve chemical information from PubChem, and write (or append) the data to a CSV file.

print_helpful_information()[source]

Prints helpful information about using the chemical_parser script.

pubchem_to_inchikey(identifier, name='smiles')[source]

A function that retrieves chemical properties from PubChem using the given identifier.

This function queries the PubChem database via the PubChemPy package and extracts a set of cheminformatic properties for the first matching compound. The properties include InChIKey, CID, IUPAC name, canonical SMILES, hydrogen bond donor and acceptor counts, rotatable bond count, and charge.

Parameters:
  • identifier (str) – The chemical identifier to query (e.g., chemical name or SMILES).

  • name (str) – The type of identifier provided. Options include 'name' or 'smiles'. Default is 'smiles'.

Returns:

A dictionary containing the chemical properties if a matching compound is found;

otherwise, None.

Return type:

dict or None
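A hedged usage sketch; the call requires network access and PubChemPy, and the exact keys of the returned dictionary should be inspected before indexing into it:

# Query PubChem by SMILES
props = pubchem_to_inchikey("CCO", name="smiles")
if props is not None:
    print(props)  # dictionary of cheminformatic properties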

read_write.struct2iupac

file2smile(filename)[source]

Reads a file and uses Open Babel to compute the SMILES, InChI, and InChIKey.

Parameters:

filename (str): name of the file containing the structure. It can be in any ASE-readable file format, as well as Q-Chem output, AMS output, or Gaussian output.

Return:

smi (str): SMILES string
inchi : InChI hash of the structure
inChiKey : 27-character InChIKey hash

main()[source]

Main function to parse command-line arguments, retrieve chemical information from PubChem, and write (or append) the data to a CSV file.

print_helpful_information()[source]

Prints helpful information about using the chemical_parser script.

pubchem_to_inchikey(filename)[source]

A function that retrieves chemical properties from PubChem using the given identifier.

This function queries the PubChem database via the PubChemPy package and extracts a set of cheminformatic properties for the first matching compound. The properties include InChIKey, CID, IUPAC name, canonical SMILES, hydrogen bond donor and acceptor counts, rotatable bond count, and charge.

Parameters:
  • identifier (str) – The chemical identifier to query (e.g., chemical name or SMILES).

  • search_type (str) – The type of identifier provided. Options include ‘name’ or ‘smiles’. Default is ‘name’.

Returns:

A dictionary containing the chemical properties if a matching compound is found;

otherwise, None.

Return type:

dict or None

Graph deep learning models

Models for predicting bond dissociation enthalpies

model.thermodynamic_stability

class EarlyStopping(patience=5, delta=0.001, path='checkpoint.pth')[source]

Bases: object

Implements early stopping to terminate training when validation loss stops improving.

patience

Number of epochs/iterations to wait for an improvement in validation loss before stopping. Default is 5.

Type:

int

delta

Minimum improvement in validation loss required to reset the patience counter. Default is 0.001.

Type:

float

best_loss

The lowest validation loss observed so far. Initialized to infinity.

Type:

float

counter

Tracks the number of consecutive epochs without improvement in validation loss.

Type:

int

early_stop

Flag indicating whether early stopping condition is met.

Type:

bool

path

File path to save the best model checkpoint. Default is ‘checkpoint.pth’.

Type:

str

__call__(val_loss, model)[source]

Evaluates the validation loss and decides whether to stop training or save a checkpoint.

save_checkpoint(model, optimizer, normalise_parameter)[source]

Saves the current model state to the checkpoint file.

Parameters:
  • model (torch.nn.Module) – The PyTorch model being trained.

  • optimizer (optim.Optimizer) – Optimizer for updating model parameters.

  • normalise_parameter (float) – Normalization parameter for target values.

Behavior:
  • Saves the model’s state dictionary to the file specified by path.
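A training-loop sketch showing the intended call pattern; the model, data loaders, optimizer, criterion, and device are assumed to be set up already, and train and evaluate are the helpers documented later on this page:

stopper = EarlyStopping(patience=5, delta=0.001, path="checkpoint.pth")
for epoch in range(100):
    train(model, train_loader, optimizer, criterion, device)
    val_loss, _ = evaluate(model, val_loader, criterion, device)
    stopper(val_loss, model)  # saves a checkpoint when the loss improves
    if stopper.early_stop:    # patience exhausted
        break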

class EnergyGNN_GAT2Conv(input_dim, hidden_dim, output_dim, edge_dim, heads=4, dropout=0.2)[source]

Bases: Module

A Graph Neural Network class for predicting the thermodynamic stability of MOFs using Graph Attention Networks (GATv2).

Parameters:

input_dim (int): Number of input node features.
hidden_dim (int): Number of hidden units in the GATv2 layers.
output_dim (int): Number of output units (e.g., 1 for regression).
edge_dim (int): Number of input edge features.
heads (int, optional): Number of attention heads. Default is 4.
dropout (float, optional): Dropout rate. Default is 0.2.

forward(data)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class EnergyGNN_GIN(input_dim, hidden_dim, output_dim, edge_dim=None, heads=None, dropout=0.2)[source]

Bases: Module

Implements a Graph Neural Network (GNN) using Graph Isomorphism Network (GIN) layers for graph-level prediction tasks such as regression or classification.

input_dim

Dimensionality of input node features.

Type:

int

hidden_dim

Dimensionality of hidden layers.

Type:

int

output_dim

Dimensionality of the output.

Type:

int

dropout

Dropout rate for regularization. Default is 0.2.

Type:

float

Layers:
  • Three GINConv layers, each using an MLP for message passing.

  • Batch normalization after each GINConv layer.

  • Global mean pooling for aggregating node-level features to graph-level.

  • Fully connected layers for output prediction.

forward(data)[source]

Performs a forward pass of the model on input graph data.

Parameters:

data (torch_geometric.data.Data) – Input graph data object containing:
  - x (torch.Tensor): Node features of shape [num_nodes, input_dim].
  - edge_index (torch.Tensor): Edge indices of shape [2, num_edges].
  - edge_attr (torch.Tensor, optional): Edge features (not used here).
  - batch (torch.Tensor): Batch indices for mini-batch training.

Returns:

Final output predictions of shape [num_graphs, output_dim].

Return type:

torch.Tensor

EnergyGNN_GIN_main_model(path_to_lmdb, hidden_dim, learning_rate, batch_size, dropout, epoch, patience, print_every, save_path)[source]

Main function for training the EnergyGNN_GIN model.

Parameters:

path_to_lmdb (str): Path to the LMDB file containing the dataset.
hidden_dim (int): Dimensionality of the hidden layer.
learning_rate (float): Learning rate for the optimizer.
batch_size (int): Batch size for training.
dropout (float): Dropout rate for the model.
epoch (int): Number of training epochs.
patience (int): Patience for early stopping.
print_every (int): Frequency of printing training and validation loss.
save_path (str): Path to save the best model.

Behavior

This function loads the dataset, trains the EnergyGNN_GIN model, and saves the best model.

GAT2Conv_main_model(path_to_lmdb, hidden_dim, learning_rate, batch_size, dropout, heads, epoch, patience, print_every, save_path)[source]

Main function for training the GAT2Conv model.

Parameters:

path_to_lmdb (str): Path to the LMDB file containing the dataset.
hidden_dim (int): Dimensionality of the hidden layer.
learning_rate (float): Learning rate for the optimizer.
batch_size (int): Batch size for training.
dropout (float): Dropout rate for the model.
heads (int): Number of attention heads for the GAT layer.
epoch (int): Number of training epochs.
patience (int): Patience for early stopping.
print_every (int): Frequency of printing training and validation loss.
save_path (str): Path to save the best model.

Behavior

This function loads the dataset, trains the GAT2Conv model, and saves the best model.

class GraphLatticeModel(input_dim, gnn_hidden_dim, lattice_hidden_dim, output_dim, num_layers=3, dropout=0.2)[source]

Bases: Module

A Graph Neural Network (GNN) model that incorporates both graph-based features and lattice matrix features for graph-level prediction tasks such as property regression or classification.

input_dim

Dimensionality of input node features.

Type:

int

gnn_hidden_dim

Dimensionality of GNN hidden layers.

Type:

int

lattice_hidden_dim

Dimensionality of hidden layers for lattice matrix processing.

Type:

int

output_dim

Dimensionality of the output (e.g., target properties).

Type:

int

num_layers

Number of GNN layers. Default is 3.

Type:

int

dropout

Dropout rate for regularization. Default is 0.2.

Type:

float

forward(data)[source]

Performs a forward pass through the model using the input graph and lattice data.

Parameters:

data (torch_geometric.data.Data) – Input data containing:
  - x (torch.Tensor): Node features of shape [num_nodes, input_dim].
  - edge_index (torch.Tensor): Edge indices of shape [2, num_edges].
  - edge_attr (torch.Tensor, optional): Edge features (not used here).
  - batch (torch.Tensor): Batch indices for mini-batch training.
  - lattice (torch.Tensor): Lattice matrix of shape [num_graphs, 3, 3].

Returns:

Final predictions of shape [num_graphs, output_dim].

Return type:

torch.Tensor

GraphLatticeModel_main_model(path_to_lmdb, hidden_dim, lattice_hidden_dim, num_layers, learning_rate, batch_size, dropout, epoch, patience, print_every, save_path)[source]

Main function for training the GraphLatticeModel model.

Parameters:

path_to_lmdb (str): Path to the LMDB file containing the dataset.
hidden_dim (int): Dimensionality of the hidden layer.
lattice_hidden_dim (int): Dimensionality of the hidden layers for lattice matrix processing.
num_layers (int): Number of GNN layers.
learning_rate (float): Learning rate for the optimizer.
batch_size (int): Batch size for training.
dropout (float): Dropout rate for the model.
epoch (int): Number of training epochs.
patience (int): Patience for early stopping.
print_every (int): Frequency of printing training and validation loss.
save_path (str): Path to save the best model.

Behavior

This function loads the dataset, trains the GraphLatticeModel, and saves the best model.

display_model_helper(model)[source]

Display clear and concise information about the selected model and its key parameters.

Parameters:

model (str) – The name of the model. Options are:
  - "EnergyGNN_GAT2Conv": Graph Attention Network (GATv2) based model.
  - "EnergyGNN_GIN": Graph Isomorphism Network (GIN) based model.
  - "GraphLatticeModel": Combines graph features with lattice matrix processing.

Instructions:
  • Review the description for each model.

  • Identify the required parameters for your use case.

  • Ensure you provide the necessary inputs when configuring the model.

Example Usage:

display_model_helper("EnergyGNN_GIN")

entry_point()[source]

Entry point for the training script. Dynamically calls the appropriate main function based on the selected model.

evaluate(model, dataloader, criterion, device)[source]

Evaluate the model using the given data and loss function, and compute accuracy.

Parameters:

model (nn.Module): The trained model to evaluate.
dataloader (DataLoader): DataLoader for batching the dataset during evaluation.
criterion (nn.Module): Loss function (e.g., CrossEntropyLoss) to compute the evaluation loss.
device (torch.device): The device (CPU or GPU) for computation.

Returns:

tuple: A tuple containing the average evaluation loss and accuracy, (average_loss, accuracy).

inverse_normalize(predictions, normalization_params, method='z-score')[source]

Inverse the normalization of predictions to the original scale.

Parameters:
  • predictions (Tensor) – The normalized predictions.

  • normalization_params (dict) – The parameters used for normalization.

  • method (str) – The normalization method (‘min-max’ or ‘z-score’).

Returns:

Predictions in the original scale.

Return type:

Tensor

load_dataset(path_to_lmdb, batch_size, train_size=0.9, random_seed=42, shuffle=True, normalize='full')[source]

Loads a dataset from an LMDB file and splits it into training, validation, and test sets.

The function uses the coords_library.LMDBDataset to load the dataset and splits it into training and test datasets. The training dataset is further split into training and validation sets. Data loaders are created for the training and validation datasets.

Parameters:
  • path_to_lmdb (str) – Path to the LMDB file containing the dataset.

  • batch_size (int) – Batch size for the DataLoader.

  • train_size (float, optional) – Fraction of the data to use for training. The rest is used for testing. Default is 0.9.

  • random_seed (int, optional) – Random seed for splitting the data. Ensures reproducibility. Default is 42.

  • shuffle (bool, optional) – Whether to shuffle the data before splitting. Default is True.

  • normalize (str, optional) – Normalization method to use. Can be ‘full’ for full normalization or ‘batch’ for batch normalization. Default is ‘full’.

Returns:

  • train_loader (DataLoader): DataLoader for the training dataset.

  • val_loader (DataLoader): DataLoader for the validation dataset.

  • test_dataset (Dataset): Dataset object containing the test data.

Return type:

tuple
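A minimal sketch of the documented call and return signature (the LMDB path is a placeholder):

train_loader, val_loader, test_dataset = load_dataset(
    path_to_lmdb="dataset.lmdb",
    batch_size=32,
    train_size=0.9,
    random_seed=42,
    shuffle=True,
    normalize="full",
)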

load_model(path='model.pth', device='cpu')[source]

Load a saved model and optimizer state from a file.

Parameters:

path (str, optional): Path to load the model from. Default is "model.pth".
device (torch.device, optional): The device (CPU or GPU) for computation. Default is "cpu".

Returns:

tuple: The loaded model and optimizer.

normalize_data(dataset, method='z-score')[source]

Normalize the target values (data.y) in the dataset.

Parameters:
  • dataset (Dataset) – The dataset object containing the data.

  • method (str) – The normalization method (‘min-max’ or ‘z-score’).

Returns:

Normalized dataset and the normalization parameters.

Return type:

tuple
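A sketch of the intended round trip with inverse_normalize; the dataset and model predictions are assumed to exist:

dataset, norm_params = normalize_data(dataset, method="z-score")
# ... train on the normalized targets and collect predictions ...
predictions = inverse_normalize(predictions, norm_params, method="z-score")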

parse_arguments()[source]

Parse command-line arguments for training the GNN model.

Returns:

Parsed command-line arguments with detailed descriptions.

Return type:

argparse.Namespace

save_model(model, optimizer, normalise_parameter, path='model.pth')[source]

Save the trained model and optimizer state to a file.

Parameters:

model (nn.Module): The trained GNN model to save.
optimizer (optim.Optimizer): Optimizer for updating model parameters.
normalise_parameter (float): Normalization parameter for target values.
path (str, optional): Path to save the model. Default is "model.pth".

train(model, dataloader, optimizer, criterion, device)[source]

Train the model using the given data and optimizer.

Parameters:

model (nn.Module): The GNN model to train.
dataloader (DataLoader): DataLoader for batching the dataset during training.
optimizer (optim.Optimizer): Optimizer for updating model parameters.
criterion (nn.Module): Loss function (e.g., MSELoss) to compute the training loss.
device (torch.device): The device (CPU or GPU) for computation.

Returns:

float: The average training loss over the epoch.
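A per-epoch sketch combining train with evaluate; the model, loaders, and optimizer are assumed to be constructed already:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion = torch.nn.MSELoss()
for epoch in range(50):
    train_loss = train(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc = evaluate(model, val_loader, criterion, device)
    print(f"epoch {epoch}: train {train_loss:.4f}, val {val_loss:.4f}")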

transform_target(test_data, normalize_param, method='z-score')[source]

Transform the target values in the test dataset according to the normalization parameters.

Parameters:
  • test_data (Dataset) – The test dataset containing the target values.

  • normalize_param (dict) – The normalization parameters (mean and std).

  • method (str, optional) – The normalization method ('z-score' or 'min-max'). Default is 'z-score'.

Returns:

The transformed test dataset.

Return type:

Dataset

model.hyper_optimiser

fine_opt_paramter(path_to_lmbd, mol_def)[source]

Fine-tuning the optimization parameters using Optuna.

Parameters:
  • path_to_lmbd (str) – Path to the LMDB dataset file.

  • mol_def (str) – Model definition ('GAT' or 'GIN').

Returns:

Validation loss.

Return type:

float

main()[source]