# -*- coding: utf-8 -*-
"""Interface for creating a custom StructureDataset."""
from pathlib import Path
from typing import Collection, List, Optional
import numpy as np
import pandas as pd
from loguru import logger
from mofdscribe.datasets.checks import check_all_file_exists
from mofdscribe.datasets.dataset import AbstractStructureDataset
from mofdscribe.datasets.utils import compress_dataset
from mofdscribe.types import PathType
__all__ = ["StructureDataset", "FrameDataset"]
[docs]class StructureDataset(AbstractStructureDataset):
"""Custom dataset class for loading structures from a files"""
def __init__(
self,
files: Collection[PathType],
df: Optional[pd.DataFrame] = None,
structure_name_column: Optional[str] = None,
year_column: Optional[str] = None,
label_columns: Optional[List[str]] = None,
feature_columns: Optional[List[str]] = None,
decorated_graph_hash_column: Optional[str] = None,
undecorated_graph_hash_column: Optional[str] = None,
decorated_scaffold_hash_column: Optional[str] = None,
undecorated_scaffold_hash_column: Optional[str] = None,
density_column: Optional[str] = None,
):
"""Initialize a dataset.
Args:
files (Collection[PathType]): List of files to load structures from.
df (Optional[pd.DataFrame], optional): Dataframe containing the structures.
Defaults to None.
structure_name_column (str): Name of the column containing the structure names.
Defaults to None.
year_column (str, optional): Name of the column containing the year of the structure.
Defaults to None.
label_columns (Optional[List[str]], optional): List of columns containing the labels.
Defaults to None.
feature_columns (Optional[List[str]], optional): List of columns containing the features.
Defaults to None.
decorated_graph_hash_column (str, optional): Name of the column containing the decorated graph hash.
Defaults to None.
undecorated_graph_hash_column (str, optional): Name of the column containing the undecorated graph hash.
Defaults to None.
decorated_scaffold_hash_column (str, optional): Name of the column containing the decorated scaffold hash.
Defaults to None.
undecorated_scaffold_hash_column (str, optional): Name of the column containing
the undecorated scaffold hash.
Defaults to None.
density_column (str, optional): Name of the column containing the density of the structure.
Defaults to None.
"""
super().__init__()
self._df = df
if self._df is not None:
compress_dataset(self._df)
self._structures = [
f for f in files if Path(f).stem in self._df[structure_name_column].values
]
else:
self._structures = files
self._structure_name_column = structure_name_column
check_all_file_exists(self._structures)
self._year_column = year_column
self._label_columns = list(label_columns) if label_columns is not None else ()
self._feature_columns = list(feature_columns) if feature_columns is not None else ()
self._decorated_graph_hash_column = decorated_graph_hash_column
self._undecorated_graph_hash_column = undecorated_graph_hash_column
self._decorated_scaffold_hash_column = decorated_scaffold_hash_column
self._undecorated_scaffold_hash_column = undecorated_scaffold_hash_column
self._density_column = density_column
self._years = None if year_column is None else self._df[year_column].values
self._labels = None if label_columns is None else self._df[label_columns].values
self._decorated_graph_hashes = (
None
if decorated_graph_hash_column is None
else self._df[decorated_graph_hash_column].values
)
self._undecorated_graph_hashes = (
None
if undecorated_graph_hash_column is None
else self._df[undecorated_graph_hash_column].values
)
self._decorated_scaffold_hashes = (
None
if decorated_scaffold_hash_column is None
else self._df[decorated_scaffold_hash_column].values
)
self._undecorated_scaffold_hashes = (
None
if undecorated_scaffold_hash_column is None
else self._df[undecorated_scaffold_hash_column].values
)
self._densities = None if density_column is None else self._df[density_column].values
def __len__(self):
"""Return number of structures in the dataset."""
return len(self._structures)
@property
def available_features(self) -> List[str]:
return self._featurenames
@property
def available_labels(self) -> List[str]:
return self._labelnames
def get_labels(self, idx: Collection[int], labelnames: Collection[str] = None) -> np.ndarray:
labelnames = labelnames if labelnames is not None else self._labelnames
return self._df.iloc[idx][list(labelnames)].values
[docs] @classmethod
def from_folder_and_dataframe(
cls,
folder: PathType,
extension: str = "cif",
dataframe: Optional[pd.DataFrame] = None,
structure_name_column: Optional[str] = None,
year_column: Optional[str] = None,
label_columns: Optional[List[str]] = None,
decorated_graph_hash_column: Optional[str] = None,
undecorated_graph_hash_column: Optional[str] = None,
decorated_scaffold_hash_column: Optional[str] = None,
undecorated_scaffold_hash_column: Optional[str] = None,
density_column: Optional[str] = None,
) -> "StructureDataset":
"""Create a dataset from a folder and a dataframe.
Args:
folder (PathType): Path to the folder containing the structures.
extension (str): Extension of the files. Defaults to 'cif'.
dataframe (Optional[pd.DataFrame], optional): Dataframe containing the structures.
Defaults to None.
structure_name_column (str): Name of the column containing the structure names.
Defaults to None.
year_column (str, optional): Name of the column containing the year of the structure.
Defaults to None.
label_columns (Optional[List[str]], optional): List of columns containing the labels.
Defaults to None.
decorated_graph_hash_column (str, optional): Name of the column containing the decorated graph hash.
Defaults to None.
undecorated_graph_hash_column (str, optional): Name of the column containing the undecorated graph hash.
Defaults to None.
decorated_scaffold_hash_column (str, optional): Name of the column containing the decorated scaffold hash.
Defaults to None.
undecorated_scaffold_hash_column (str, optional): Name of the column containing the undecorated scaffold
hash. Defaults to None.
density_column (str, optional): Name of the column containing the density of the structure.
Defaults to None.
Returns:
StructureDataset: Dataset containing the structures.
"""
all_files = list(Path(folder).rglob(f"*.{extension}"))
return cls(
all_files,
dataframe,
structure_name_column,
year_column,
label_columns,
decorated_graph_hash_column,
undecorated_graph_hash_column,
decorated_scaffold_hash_column,
undecorated_scaffold_hash_column,
density_column,
)
[docs]class FrameDataset(AbstractStructureDataset):
"""Dataset containing structure information read from a dataframe."""
def __init__(
self,
df: pd.DataFrame,
structure_name_column: str,
year_column: Optional[str] = None,
label_columns: Optional[List[str]] = None,
decorated_graph_hash_column: Optional[str] = None,
undecorated_graph_hash_column: Optional[str] = None,
decorated_scaffold_hash_column: Optional[str] = None,
undecorated_scaffold_hash_column: Optional[str] = None,
density_column: Optional[str] = None,
):
"""Initialize the dataset.
Args:
df (pd.DataFrame): Dataframe containing the structures.
structure_name_column (str): Name of the column containing the structure names.
year_column (str, optional): Name of the column containing the year of the structure.
Defaults to None.
label_columns (Optional[List[str]], optional): List of columns containing the labels.
Defaults to None.
decorated_graph_hash_column (str, optional): Name of the column containing the decorated graph hash.
Defaults to None.
undecorated_graph_hash_column (str, optional): Name of the column containing the undecorated graph hash.
Defaults to None.
decorated_scaffold_hash_column (str, optional): Name of the column containing the decorated scaffold hash.
Defaults to None.
undecorated_scaffold_hash_column (str, optional): Name of the column containing the undecorated scaffold
hash. Defaults to None.
density_column (str, optional): Name of the column containing the density of the structure.
Defaults to None.
"""
super().__init__()
logger.warning(
"FrameDataset support is experimental. Some splitter integrations may not work."
)
self._df = df
compress_dataset(self._df)
self._structure_name_column = structure_name_column
self._year_column = year_column
self._label_columns = list(label_columns) if label_columns is not None else ()
self._decorated_graph_hash_column = decorated_graph_hash_column
self._undecorated_graph_hash_column = undecorated_graph_hash_column
self._decorated_scaffold_hash_column = decorated_scaffold_hash_column
self._undecorated_scaffold_hash_column = undecorated_scaffold_hash_column
self._density_column = density_column
self._years = None if year_column is None else self._df[year_column]
self._labels = None if label_columns is None else self._df[label_columns].values
self._decorated_graph_hashes = (
None
if decorated_graph_hash_column is None
else self._df[decorated_graph_hash_column].values
)
self._undecorated_graph_hashes = (
None
if undecorated_graph_hash_column is None
else self._df[undecorated_graph_hash_column].values
)
self._decorated_scaffold_hashes = (
None
if decorated_scaffold_hash_column is None
else self._df[decorated_scaffold_hash_column].values
)
self._undecorated_scaffold_hashes = (
None
if undecorated_scaffold_hash_column is None
else self._df[undecorated_scaffold_hash_column].values
)
self._densities = None if density_column is None else self._df[density_column].values
def __len__(self):
"""Return number of structures in the dataset."""
return len(self._df)
[docs] def get_subset(self, indices: Collection[int]) -> "FrameDataset":
"""Get a subset of the dataset.
Args:
indices (Collection[int]): indices of the structures to include.
Returns:
FrameDataset: a new dataset containing only the structures
specified by the indices.
"""
return FrameDataset(
self._df.iloc[indices],
structure_name_column=self._structure_name_column,
year_column=self._year_column,
label_columns=self._label_columns,
decorated_graph_hash_column=self._decorated_graph_hash_column,
undecorated_graph_hash_column=self._undecorated_graph_hash_column,
decorated_scaffold_hash_column=self._decorated_scaffold_hash_column,
undecorated_scaffold_hash_column=self._undecorated_scaffold_hash_column,
density_column=self._density_column,
)
@property
def available_features(self) -> List[str]:
return self._featurenames
@property
def available_labels(self) -> List[str]:
return self._labelnames
def get_labels(self, idx: Collection[int], labelnames: Collection[str] = None) -> np.ndarray:
labelnames = labelnames if labelnames is not None else self._labelnames
return self._df.iloc[idx][list(labelnames)].values