
Format loading

GeneralStructure

Bases: ABC

Abstract factory class for dataset types in the openQDC package.

Source code in openqdc/datasets/structure.py
class GeneralStructure(ABC):
    """
    Abstract factory class for dataset types in the openQDC package.
    """

    _ext: Optional[str] = None
    _extra_files: Optional[List[str]] = None

    @property
    def ext(self):
        return self._ext

    @property
    @abstractmethod
    def load_fn(self) -> Callable:
        """
        Function to use for loading the data.
        Must be implemented by the child class.

        Returns:
            the function to use for loading the data
        """
        raise NotImplementedError

    def add_extension(self, filename: str) -> str:
        """
        Add the correct extension to a filename

        Parameters:
            filename:  the filename to add the extension to

        Returns:
            the filename with the extension
        """
        return filename + self.ext

    @abstractmethod
    def save_preprocess(
        self,
        preprocess_path: Union[str, PathLike],
        data_keys: List[str],
        data_dict: Dict[str, np.ndarray],
        extra_data_keys: List[str],
        extra_data_types: Dict[str, type],
    ) -> List[str]:
        """
        Save the preprocessed data to the cache directory and optionally upload it to the remote storage.
        Must be implemented by the child class.

        Parameters:
            preprocess_path:  path to the preprocessed data file
            data_keys:        list of keys to load from the data file
            data_dict:        dictionary of data to save
            extra_data_keys:  list of keys to load from the extra data file
            extra_data_types: dictionary of data types for each key
        """
        raise NotImplementedError

    @abstractmethod
    def load_extra_files(
        self,
        data: Dict[str, np.ndarray],
        preprocess_path: Union[str, PathLike],
        data_keys: List[str],
        pkl_data_keys: List[str],
        overwrite: bool,
    ):
        """
        Load extra files required to define other types of data.
        Must be implemented by the child class.

        Parameters:
            data:  dictionary of data to load
            preprocess_path:  path to the preprocessed data file
            data_keys:    list of keys to load from the data file
            pkl_data_keys:   list of keys to load from the extra files
            overwrite:   whether to overwrite the local cache
        """
        raise NotImplementedError

    def join_and_ext(self, path: Union[str, PathLike], filename: str) -> Union[str, PathLike]:
        """
        Join a path and a filename and add the correct extension.

        Parameters:
            path:  the path to join
            filename:  the filename to join

        Returns:
            the joined path with the correct extension
        """
        return p_join(path, self.add_extension(filename))

    def load_data(
        self,
        preprocess_path: Union[str, PathLike],
        data_keys: List[str],
        data_types: Dict[str, np.dtype],
        data_shapes: Dict[str, Tuple[int, int]],
        extra_data_keys: List[str],
        overwrite: bool,
    ):
        """
        Main method to load the data from a filetype structure like memmap or zarr.

        Parameters:
            preprocess_path:  path to the preprocessed data file
            data_keys:        list of keys to load from the data file
            data_types:       dictionary of data types for each key
            data_shapes:      dictionary of shapes for each key
            extra_data_keys:  list of keys to load from the extra data file
            overwrite:        whether to overwrite the local cache
        """
        data = {}
        for key in data_keys:
            filename = self.join_and_ext(preprocess_path, key)
            pull_locally(filename, overwrite=overwrite)
            data[key] = self.load_fn(filename, mode="r", dtype=data_types[key])
            data[key] = self.unpack(data[key])
            data[key] = data[key].reshape(*data_shapes[key])

        data = self.load_extra_files(data, preprocess_path, data_keys, extra_data_keys, overwrite)
        return data

    def unpack(self, data: any) -> any:
        """
        Unpack the data from the loaded file.

        Parameters:
            data:  the data to unpack

        Returns:
            the unpacked data
        """
        return data
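
For orientation, the sketch below shows how a new concrete structure could plug into this interface. It is a minimal, hypothetical example: the NpyDataset name, the .npy handling and the lambda adapter are illustrative only and are not part of openQDC.

import numpy as np

from openqdc.datasets.structure import GeneralStructure


class NpyDataset(GeneralStructure):
    """Hypothetical structure backed by plain .npy files (illustration only)."""

    _ext = ".npy"
    _extra_files = []

    @property
    def load_fn(self):
        # Adapt np.load to the (filename, mode=..., dtype=...) call made by load_data.
        return lambda filename, mode="r", dtype=None: np.load(filename, mmap_mode=mode)

    def save_preprocess(self, preprocess_path, data_keys, data_dict, extra_data_keys, extra_data_types):
        paths = []
        for key in data_keys:
            path = self.join_and_ext(preprocess_path, key)
            np.save(path, data_dict[key])  # path already carries the .npy extension
            paths.append(path)
        return paths

    def load_extra_files(self, data, preprocess_path, data_keys, pkl_data_keys, overwrite):
        # This toy structure stores everything as arrays, so there is nothing extra to load.
        return data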

load_fn: Callable abstractmethod property

Function to use for loading the data. Must be implemented by the child class.

Returns:

    Callable: the function to use for loading the data

add_extension(filename)

Add the correct extension to a filename

Parameters:

    filename (str, required): the filename to add the extension to

Returns:

    str: the filename with the extension

Source code in openqdc/datasets/structure.py
def add_extension(self, filename: str) -> str:
    """
    Add the correct extension to a filename

    Parameters:
        filename:  the filename to add the extension to

    Returns:
        the filename with the extension
    """
    return filename + self.ext
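
For instance, with the memmap structure documented below, a bare key becomes a cache filename (the key name is illustrative):

MemMapDataset().add_extension("energies")   # -> "energies.mmap"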

join_and_ext(path, filename)

Join a path and a filename and add the correct extension.

Parameters:

    path (Union[str, PathLike], required): the path to join
    filename (str, required): the filename to join

Returns:

    Union[str, PathLike]: the joined path with the correct extension

Source code in openqdc/datasets/structure.py
def join_and_ext(self, path: Union[str, PathLike], filename: str) -> Union[str, PathLike]:
    """
    Join a path and a filename and add the correct extension.

    Parameters:
        path:  the path to join
        filename:  the filename to join

    Returns:
        the joined path with the correct extension
    """
    return p_join(path, self.add_extension(filename))
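
Assuming p_join behaves like os.path.join, this yields the full cached file path (the directory and key below are illustrative):

MemMapDataset().join_and_ext("/cache/my_dataset", "energies")   # -> "/cache/my_dataset/energies.mmap"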

load_data(preprocess_path, data_keys, data_types, data_shapes, extra_data_keys, overwrite)

Main method to load the data from a filetype structure like memmap or zarr.

Parameters:

    preprocess_path (Union[str, PathLike], required): path to the preprocessed data file
    data_keys (List[str], required): list of keys to load from the data file
    data_types (Dict[str, dtype], required): dictionary of data types for each key
    data_shapes (Dict[str, Tuple[int, int]], required): dictionary of shapes for each key
    extra_data_keys (List[str], required): list of keys to load from the extra data file
    overwrite (bool, required): whether to overwrite the local cache

Source code in openqdc/datasets/structure.py
def load_data(
    self,
    preprocess_path: Union[str, PathLike],
    data_keys: List[str],
    data_types: Dict[str, np.dtype],
    data_shapes: Dict[str, Tuple[int, int]],
    extra_data_keys: List[str],
    overwrite: bool,
):
    """
    Main method to load the data from a filetype structure like memmap or zarr.

    Parameters:
        preprocess_path:  path to the preprocessed data file
        data_keys:        list of keys to load from the data file
        data_types:       dictionary of data types for each key
        data_shapes:      dictionary of shapes for each key
        extra_data_keys:  list of keys to load from the extra data file
        overwrite:        whether to overwrite the local cache
    """
    data = {}
    for key in data_keys:
        filename = self.join_and_ext(preprocess_path, key)
        pull_locally(filename, overwrite=overwrite)
        data[key] = self.load_fn(filename, mode="r", dtype=data_types[key])
        data[key] = self.unpack(data[key])
        data[key] = data[key].reshape(*data_shapes[key])

    data = self.load_extra_files(data, preprocess_path, data_keys, extra_data_keys, overwrite)
    return data
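
A sketch of a typical call, assuming the .mmap files are already present in (or can be pulled into) the local cache; the path, keys, dtypes and shapes below are purely illustrative:

import numpy as np

from openqdc.datasets.structure import MemMapDataset

structure = MemMapDataset()
data = structure.load_data(
    preprocess_path="/cache/my_dataset/preprocessed",
    data_keys=["atomic_inputs", "energies"],
    data_types={"atomic_inputs": np.float32, "energies": np.float64},
    data_shapes={"atomic_inputs": (-1, 5), "energies": (-1, 1)},
    extra_data_keys=["name", "subset"],
    overwrite=False,
)
# data maps each key to an array; string fields such as "name" and "subset"
# are rebuilt by load_extra_files from their unique/inverse encoding.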

load_extra_files(data, preprocess_path, data_keys, pkl_data_keys, overwrite) abstractmethod

Load extra files required to define other types of data. Must be implemented by the child class.

Parameters:

    data (Dict[str, ndarray], required): dictionary of data to load
    preprocess_path (Union[str, PathLike], required): path to the preprocessed data file
    data_keys (List[str], required): list of keys to load from the data file
    pkl_data_keys (List[str], required): list of keys to load from the extra files
    overwrite (bool, required): whether to overwrite the local cache

Source code in openqdc/datasets/structure.py
@abstractmethod
def load_extra_files(
    self,
    data: Dict[str, np.ndarray],
    preprocess_path: Union[str, PathLike],
    data_keys: List[str],
    pkl_data_keys: List[str],
    overwrite: bool,
):
    """
    Load extra files required to define other types of data.
    Must be implemented by the child class.

    Parameters:
        data:  dictionary of data to load
        preprocess_path:  path to the preprocessed data file
        data_keys:    list of keys to load from the data file
        pkl_data_keys:   list of keys to load from the extra files
        overwrite:   whether to overwrite the local cache
    """
    raise NotImplementedError

save_preprocess(preprocess_path, data_keys, data_dict, extra_data_keys, extra_data_types) abstractmethod

Save the preprocessed data to the cache directory and optionally upload it to the remote storage. Must be implemented by the child class.

Parameters:

    preprocess_path (Union[str, PathLike], required): path to the preprocessed data file
    data_keys (List[str], required): list of keys to load from the data file
    data_dict (Dict[str, ndarray], required): dictionary of data to save
    extra_data_keys (List[str], required): list of keys to load from the extra data file
    extra_data_types (Dict[str, type], required): dictionary of data types for each key

Source code in openqdc/datasets/structure.py
@abstractmethod
def save_preprocess(
    self,
    preprocess_path: Union[str, PathLike],
    data_keys: List[str],
    data_dict: Dict[str, np.ndarray],
    extra_data_keys: List[str],
    extra_data_types: Dict[str, type],
) -> List[str]:
    """
    Save the preprocessed data to the cache directory and optionally upload it to the remote storage.
    Must be implemented by the child class.

    Parameters:
        preprocess_path:  path to the preprocessed data file
        data_keys:        list of keys to load from the data file
        data_dict:        dictionary of data to save
        extra_data_keys:  list of keys to load from the extra data file
        extra_data_types: dictionary of data types for each key
    """
    raise NotImplementedError

unpack(data)

Unpack the data from the loaded file.

Parameters:

    data (any, required): the data to unpack

Returns:

    any: the unpacked data

Source code in openqdc/datasets/structure.py
def unpack(self, data: any) -> any:
    """
    Unpack the data from the loaded file.

    Parameters:
        data:  the data to unpack

    Returns:
        the unpacked data
    """
    return data

MemMapDataset

Bases: GeneralStructure

Dataset structure for memory-mapped numpy arrays and props.pkl files.

Source code in openqdc/datasets/structure.py
class MemMapDataset(GeneralStructure):
    """
    Dataset structure for memory-mapped numpy arrays and props.pkl files.
    """

    _ext = ".mmap"
    _extra_files = ["props.pkl"]

    @property
    def load_fn(self):
        return np.memmap

    def save_preprocess(self, preprocess_path, data_keys, data_dict, extra_data_keys, extra_data_types) -> List[str]:
        local_paths = []
        for key in data_keys:
            local_path = self.join_and_ext(preprocess_path, key)
            out = np.memmap(local_path, mode="w+", dtype=data_dict[key].dtype, shape=data_dict[key].shape)
            out[:] = data_dict.pop(key)[:]
            out.flush()
            local_paths.append(local_path)

        # save smiles and subset
        local_path = p_join(preprocess_path, "props.pkl")

        # assert that (required) pkl keys are present in data_dict
        assert all([key in data_dict.keys() for key in extra_data_keys])

        # store unique and inverse indices for str-based pkl keys
        for key in extra_data_keys:
            if extra_data_types[key] == str:
                data_dict[key] = np.unique(data_dict[key], return_inverse=True)

        with open(local_path, "wb") as f:
            pkl.dump(data_dict, f)

        local_paths.append(local_path)
        return local_paths

    def load_extra_files(self, data, preprocess_path, data_keys, pkl_data_keys, overwrite):
        filename = p_join(preprocess_path, "props.pkl")
        pull_locally(filename, overwrite=overwrite)
        with open(filename, "rb") as f:
            tmp = pkl.load(f)
            all_pkl_keys = set(tmp.keys()) - set(data_keys)
            # assert required pkl_keys are present in all_pkl_keys
            assert all([key in all_pkl_keys for key in pkl_data_keys])
            for key in all_pkl_keys:
                x = tmp.pop(key)
                if len(x) == 2:
                    data[key] = x[0][x[1]]
                else:
                    data[key] = x
        return data
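
A minimal save-side sketch, assuming a writable local directory; the directory name, keys and arrays are purely illustrative:

import os

import numpy as np

from openqdc.datasets.structure import MemMapDataset

structure = MemMapDataset()
os.makedirs("/tmp/openqdc_demo", exist_ok=True)

data_dict = {
    "energies": np.random.rand(10, 1).astype(np.float64),   # written to energies.mmap
    "name": np.array([f"mol_{i}" for i in range(10)]),      # str data, pickled as (unique, inverse) in props.pkl
}

paths = structure.save_preprocess(
    preprocess_path="/tmp/openqdc_demo",
    data_keys=["energies"],
    data_dict=data_dict,
    extra_data_keys=["name"],
    extra_data_types={"name": str},
)
# paths -> ["/tmp/openqdc_demo/energies.mmap", "/tmp/openqdc_demo/props.pkl"]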

ZarrDataset

Bases: GeneralStructure

Dataset structure for zarr files.

Source code in openqdc/datasets/structure.py
class ZarrDataset(GeneralStructure):
    """
    Dataset structure for zarr files.
    """

    _ext = ".zip"
    _extra_files = ["metadata.zip"]
    _zarr_version = 2

    @property
    def load_fn(self):
        return zarr.open

    def unpack(self, data):
        return data[:]

    def save_preprocess(self, preprocess_path, data_keys, data_dict, extra_data_keys, extra_data_types) -> List[str]:
        # os.makedirs(p_join(ds.root, "zips",  ds.__name__), exist_ok=True)
        local_paths = []
        for key, value in data_dict.items():
            if key not in data_keys:
                continue
            zarr_path = self.join_and_ext(preprocess_path, key)
            value = data_dict.pop(key)
            z = zarr.open(
                zarr.storage.ZipStore(zarr_path),
                "w",
                zarr_version=self._zarr_version,
                shape=value.shape,
                dtype=value.dtype,
            )
            z[:] = value[:]
            local_paths.append(zarr_path)
            # if key in attrs:
            #    z.attrs.update(attrs[key])

        metadata = p_join(preprocess_path, "metadata.zip")

        group = zarr.group(zarr.storage.ZipStore(metadata))

        for key in extra_data_keys:
            if extra_data_types[key] == str:
                data_dict[key] = np.unique(data_dict[key], return_inverse=True)

        for key, value in data_dict.items():
            # sub=group.create_group(key)
            if key in ["name", "subset"]:
                data = group.create_dataset(key, shape=value[0].shape, dtype=value[0].dtype)
                data[:] = value[0][:]
                data2 = group.create_dataset(key + "_ptr", shape=value[1].shape, dtype=np.int32)
                data2[:] = value[1][:]
            else:
                data = group.create_dataset(key, shape=value.shape, dtype=value.dtype)
                data[:] = value[:]
        local_paths.append(metadata)
        return local_paths

    def load_extra_files(self, data, preprocess_path, data_keys, pkl_data_keys, overwrite):
        filename = self.join_and_ext(preprocess_path, "metadata")
        pull_locally(filename, overwrite=overwrite)
        tmp = self.load_fn(filename)
        all_pkl_keys = set(tmp.keys()) - set(data_keys)
        # assert required pkl_keys are present in all_pkl_keys
        assert all([key in all_pkl_keys for key in pkl_data_keys])
        for key in all_pkl_keys:
            if key not in pkl_data_keys:
                data[key] = tmp[key][:][tmp[key][:]]
            else:
                data[key] = tmp[key][:]
        return data
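
A small sketch of how the zip-backed zarr stores are read back, assuming zarr v2; the path and array are illustrative. load_fn is zarr.open, and unpack materializes the lazy zarr array (data[:]) into memory:

import numpy as np
import zarr

from openqdc.datasets.structure import ZarrDataset

structure = ZarrDataset()

# Write a toy array to a zip-backed store.
store = zarr.storage.ZipStore("/tmp/energies.zip", mode="w")
z = zarr.open(store, "w", shape=(4, 1), dtype=np.float64)
z[:] = np.arange(4, dtype=np.float64).reshape(4, 1)
store.close()

# Read it back through the structure's hooks.
lazy = structure.load_fn("/tmp/energies.zip", mode="r")
arr = structure.unpack(lazy)   # equivalent to lazy[:], a plain numpy array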