Skip to content

Models

LUMEBaseModel

Bases: BaseModel, ABC

Abstract base class for models using lume-model variables.

Inheriting classes must implement the _evaluate method, and variable names must be unique. Models built using this framework will be compatible with the lume-epics EPICS server and associated tools.

Attributes:

Name Type Description
input_variables list[ScalarVariable]

List defining the input variables and their order.

output_variables list[ScalarVariable]

List defining the output variables and their order.

input_validation_config Optional[dict[str, ConfigEnum]]

Determines the behavior during input validation by specifying the validation config for each input variable: {var_name: value}. Value can be "warn", "error", or None.

output_validation_config Optional[dict[str, ConfigEnum]]

Determines the behavior during output validation by specifying the validation config for each output variable: {var_name: value}. Value can be "warn", "error", or None.

Source code in lume_model/base.py
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
class LUMEBaseModel(BaseModel, ABC):
    """Abstract base class for models using lume-model variables.

    Inheriting classes must implement the ``_evaluate`` method, and variable names must be unique.
    Models built using this framework will be compatible with the lume-epics EPICS server and associated tools.

    Attributes:
        input_variables: List defining the input variables and their order.
        output_variables: List defining the output variables and their order.
        input_validation_config: Determines the behavior during input validation by specifying the validation
          config for each input variable: {var_name: value}. Value can be "warn", "error", or None.
        output_validation_config: Determines the behavior during output validation by specifying the validation
          config for each output variable: {var_name: value}. Value can be "warn", "error", or None.
    """
    input_variables: list[ScalarVariable]
    output_variables: list[ScalarVariable]
    input_validation_config: Optional[dict[str, ConfigEnum]] = None
    output_validation_config: Optional[dict[str, ConfigEnum]] = None

    model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True)

    @field_validator("input_variables", "output_variables", mode="before")
    def validate_input_variables(cls, value):
        """Normalizes variable specifications before regular field validation.

        A dict mapping names to either variable configs (dicts with a ``variable_class`` key, instantiated via
        ``get_variable``) or ``ScalarVariable`` instances is converted into a list of variables. A list is passed
        through unchanged; any other value is replaced by an empty list.

        Raises:
            TypeError: If a dict entry is neither a dict nor a ``ScalarVariable``.
        """
        new_value = []
        if isinstance(value, dict):
            for name, val in value.items():
                if isinstance(val, dict):
                    variable_class = get_variable(val["variable_class"])
                    new_value.append(variable_class(name=name, **val))
                elif isinstance(val, ScalarVariable):
                    new_value.append(val)
                else:
                    raise TypeError(f"type {type(val)} not supported")
        elif isinstance(value, list):
            new_value = value
        return new_value

    def __init__(self, *args, **kwargs):
        """Initializes LUMEBaseModel.

        Args:
            *args: Accepts a single argument which is the model configuration as dictionary, YAML or JSON
              formatted string or file path.
            **kwargs: See class attributes.

        Raises:
            ValueError: If more than one positional argument is given, or a positional argument is combined
              with keyword arguments.
        """
        if len(args) == 1:
            if len(kwargs) > 0:
                raise ValueError("Cannot specify YAML string and keyword arguments for LUMEBaseModel init.")
            # look the field definitions up on the class: the instance is not initialized yet, and accessing
            # ``model_fields`` through an instance is deprecated in newer pydantic releases
            super().__init__(**parse_config(args[0], type(self).model_fields))
        elif len(args) > 1:
            raise ValueError(
                "Arguments to LUMEBaseModel must be either a single YAML string "
                "or keyword arguments passed directly to pydantic."
            )
        else:
            super().__init__(**kwargs)

    @field_validator("input_variables", "output_variables")
    def unique_variable_names(cls, value):
        """Ensures that variable names within the list are unique (delegates to verify_unique_variable_names)."""
        verify_unique_variable_names(value)
        return value

    @field_validator("input_variables")
    def verify_input_default_value(cls, value):
        """Verifies that input variables have the required default values.

        Only a missing (``None``) default is rejected: falsy defaults such as ``0.0`` are valid, and array-like
        defaults must not be evaluated for truthiness.

        Raises:
            ValueError: If an input variable has no default value.
        """
        for var in value:
            if var.default_value is None:
                raise ValueError(f"Input variable {var.name} must have a default value.")
        return value

    @property
    def input_names(self) -> list[str]:
        """Names of the input variables, in order."""
        return [var.name for var in self.input_variables]

    @property
    def output_names(self) -> list[str]:
        """Names of the output variables, in order."""
        return [var.name for var in self.output_variables]

    @property
    def default_input_validation_config(self) -> dict[str, ConfigEnum]:
        """Determines default behavior during input validation (if input_validation_config is None)."""
        return {var.name: var.default_validation_config for var in self.input_variables}

    @property
    def default_output_validation_config(self) -> dict[str, ConfigEnum]:
        """Determines default behavior during output validation (if output_validation_config is None)."""
        return {var.name: var.default_validation_config for var in self.output_variables}

    def evaluate(self, input_dict: dict[str, Any]) -> dict[str, Any]:
        """Main evaluation function, child classes must implement the _evaluate method.

        Validates the inputs, evaluates the model on the validated inputs and validates the outputs before
        returning them.
        """
        validated_input_dict = self.input_validation(input_dict)
        output_dict = self._evaluate(validated_input_dict)
        self.output_validation(output_dict)
        return output_dict

    @abstractmethod
    def _evaluate(self, input_dict: dict[str, Any]) -> dict[str, Any]:
        """Evaluates the model on already validated inputs; must be implemented by subclasses."""
        pass

    def input_validation(self, input_dict: dict[str, Any]) -> dict[str, Any]:
        """Validates each provided input value against its input variable definition.

        Args:
            input_dict: Dictionary of input variable names to values.

        Returns:
            The unchanged input dictionary.

        Raises:
            ValueError: If a key does not name an input variable.
        """
        variables = {var.name: var for var in self.input_variables}
        config = self.input_validation_config or {}
        for name, value in input_dict.items():
            if name not in variables:
                raise ValueError(f"Unknown input variable {name}, expected one of {list(variables)}.")
            variables[name].validate_value(value, config=config.get(name))
        return input_dict

    def output_validation(self, output_dict: dict[str, Any]) -> dict[str, Any]:
        """Validates each output value against its output variable definition.

        Args:
            output_dict: Dictionary of output variable names to values.

        Returns:
            The unchanged output dictionary.

        Raises:
            ValueError: If a key does not name an output variable.
        """
        variables = {var.name: var for var in self.output_variables}
        config = self.output_validation_config or {}
        for name, value in output_dict.items():
            if name not in variables:
                raise ValueError(f"Unknown output variable {name}, expected one of {list(variables)}.")
            variables[name].validate_value(value, config=config.get(name))
        return output_dict

    def to_json(self, **kwargs) -> str:
        """Serializes the model to a JSON string via json_dumps (kwargs are forwarded)."""
        return json_dumps(self, **kwargs)

    def model_dump(self, **kwargs) -> dict[str, Any]:
        """Dumps the model to a dict whose first key is ``model_class`` (the concrete class name)."""
        config = super().model_dump(**kwargs)
        config["input_variables"] = [var.model_dump() for var in self.input_variables]
        config["output_variables"] = [var.model_dump() for var in self.output_variables]
        return {"model_class": self.__class__.__name__} | config

    def json(self, **kwargs) -> str:
        """Returns the output of ``to_json`` re-encoded with default ``json.dumps`` settings."""
        result = self.to_json(**kwargs)
        config = json.loads(result)
        return json.dumps(config)

    def yaml(
            self,
            base_key: str = "",
            file_prefix: str = "",
            save_models: bool = False,
    ) -> str:
        """Serializes the object and returns a YAML formatted string defining the model.

        Args:
            base_key: Base key for serialization.
            file_prefix: Prefix for generated filenames.
            save_models: Determines whether models are saved to file.

        Returns:
            YAML formatted string defining the model.
        """
        output = json.loads(
            self.to_json(
                base_key=base_key,
                file_prefix=file_prefix,
                save_models=save_models,
            )
        )
        s = yaml.dump(output, default_flow_style=None, sort_keys=False)
        return s

    def dump(
            self,
            file: Union[str, os.PathLike],
            base_key: str = "",
            save_models: bool = True,
    ):
        """Saves the YAML formatted string defining the model to file.

        The path of ``file`` without its extension is used as ``file_prefix`` for generated filenames.

        Args:
            file: File path to which the YAML formatted string and corresponding files are saved.
            base_key: Base key for serialization.
            save_models: Determines whether models are saved to file.
        """
        file_prefix = os.path.splitext(os.path.abspath(file))[0]
        # serialize before opening: opening with "w" truncates the file, so a failing serialization
        # would otherwise destroy an existing model definition
        yaml_str = self.yaml(
            base_key=base_key,
            file_prefix=file_prefix,
            save_models=save_models,
        )
        with open(file, "w") as f:
            f.write(yaml_str)

    @classmethod
    def from_file(cls, filename: str):
        """Loads a model from a YAML file.

        Raises:
            OSError: If the file does not exist.
        """
        if not os.path.exists(filename):
            raise OSError(f"File {filename} is not found.")
        with open(filename, "r") as file:
            return cls.from_yaml(file)

    @classmethod
    def from_yaml(cls, yaml_obj: Union[str, TextIOWrapper]):
        """Creates a model from a YAML string or an open text stream (parsed by parse_config)."""
        return cls.model_validate(parse_config(yaml_obj, cls.model_fields))

default_input_validation_config property

Determines default behavior during input validation (if input_validation_config is None).

default_output_validation_config property

Determines default behavior during output validation (if output_validation_config is None).

__init__(*args, **kwargs)

Initializes LUMEBaseModel.

Parameters:

Name Type Description Default
*args

Accepts a single argument which is the model configuration as dictionary, YAML or JSON formatted string or file path.

()
**kwargs

See class attributes.

{}
Source code in lume_model/base.py
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
def __init__(self, *args, **kwargs):
    """Initializes LUMEBaseModel.

    Args:
        *args: At most one positional argument: the model configuration as a dictionary, a YAML or JSON
          formatted string, or a file path.
        **kwargs: Field values (see class attributes); cannot be combined with a positional argument.

    Raises:
        ValueError: If more than one positional argument is given, or a positional argument is combined
          with keyword arguments.
    """
    if len(args) > 1:
        raise ValueError(
            "Arguments to LUMEBaseModel must be either a single YAML string "
            "or keyword arguments passed directly to pydantic."
        )
    if not args:
        # plain keyword construction, handed straight to pydantic
        super().__init__(**kwargs)
        return
    if kwargs:
        raise ValueError("Cannot specify YAML string and keyword arguments for LUMEBaseModel init.")
    super().__init__(**parse_config(args[0], self.model_fields))

dump(file, base_key='', save_models=True)

Saves the YAML formatted string defining the model to file (and, if save_models is true, the corresponding model files).

Parameters:

Name Type Description Default
file Union[str, PathLike]

File path to which the YAML formatted string and corresponding files are saved.

required
base_key str

Base key for serialization.

''
save_models bool

Determines whether models are saved to file.

True
Source code in lume_model/base.py
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
def dump(
        self,
        file: Union[str, os.PathLike],
        base_key: str = "",
        save_models: bool = True,
):
    """Saves the YAML formatted string defining the model to file.

    The path of ``file`` without its extension is used as ``file_prefix`` for generated filenames.

    Args:
        file: File path to which the YAML formatted string and corresponding files are saved.
        base_key: Base key for serialization.
        save_models: Determines whether models are saved to file.
    """
    file_prefix = os.path.splitext(os.path.abspath(file))[0]
    # serialize before opening: opening with "w" truncates the file, so a failing serialization
    # would otherwise destroy an existing model definition
    yaml_str = self.yaml(
        base_key=base_key,
        file_prefix=file_prefix,
        save_models=save_models,
    )
    with open(file, "w") as f:
        f.write(yaml_str)

evaluate(input_dict)

Main evaluation function, child classes must implement the _evaluate method.

Source code in lume_model/base.py
306
307
308
309
310
311
def evaluate(self, input_dict: dict[str, Any]) -> dict[str, Any]:
    """Runs the validate -> evaluate -> validate pipeline; subclasses supply ``_evaluate``.

    Args:
        input_dict: Raw input values keyed by variable name.

    Returns:
        The dictionary produced by ``_evaluate`` (after output validation has run on it).
    """
    checked_inputs = self.input_validation(input_dict)
    result = self._evaluate(checked_inputs)
    # the return value of output_validation is intentionally discarded
    self.output_validation(result)
    return result

verify_input_default_value(value)

Verifies that input variables have the required default values.

Source code in lume_model/base.py
280
281
282
283
284
285
286
@field_validator("input_variables")
def verify_input_default_value(cls, value):
    """Verifies that input variables have the required default values.

    Only a missing (``None``) default is rejected: falsy defaults such as ``0.0`` are valid, and array-like
    defaults must not be evaluated for truthiness.

    Raises:
        ValueError: If an input variable has no default value.
    """
    for var in value:
        if var.default_value is None:
            raise ValueError(f"Input variable {var.name} must have a default value.")
    return value

yaml(base_key='', file_prefix='', save_models=False)

Serializes the object and returns a YAML formatted string defining the model.

Parameters:

Name Type Description Default
base_key str

Base key for serialization.

''
file_prefix str

Prefix for generated filenames.

''
save_models bool

Determines whether models are saved to file.

False

Returns:

Type Description
str

YAML formatted string defining the model.

Source code in lume_model/base.py
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
def yaml(
        self,
        base_key: str = "",
        file_prefix: str = "",
        save_models: bool = False,
) -> str:
    """Returns the model definition as a YAML document.

    The model is first serialized with ``to_json`` and the decoded JSON is re-emitted as YAML, keeping
    key order (``sort_keys=False``).

    Args:
        base_key: Base key for serialization.
        file_prefix: Prefix for generated filenames.
        save_models: Determines whether models are saved to file.

    Returns:
        YAML formatted string defining the model.
    """
    json_text = self.to_json(
        base_key=base_key,
        file_prefix=file_prefix,
        save_models=save_models,
    )
    return yaml.dump(json.loads(json_text), default_flow_style=None, sort_keys=False)

TorchModel

Bases: LUMEBaseModel

LUME-model class for torch models.

By default, the models are assumed to be fixed, so all gradient computation is deactivated and the model and transformers are put in evaluation mode.

Attributes:

Name Type Description
model Module

The torch base model.

input_variables list[ScalarVariable]

List defining the input variables and their order.

output_variables list[ScalarVariable]

List defining the output variables and their order.

input_transformers list[Union[ReversibleInputTransform, Linear]]

List of transformer objects to apply to input before passing to model.

output_transformers list[Union[ReversibleInputTransform, Linear]]

List of transformer objects to apply to output of model.

output_format str

Determines format of outputs: "tensor" or "raw".

device Union[device, str]

Device on which the model will be evaluated. Defaults to "cpu".

fixed_model bool

If true, the model and transformers are put in evaluation mode and all gradient computation is deactivated.

precision str

Precision of the model, either "double" or "single".

Source code in lume_model/models/torch_model.py
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
class TorchModel(LUMEBaseModel):
    """LUME-model class for torch models.

    By default, the models are assumed to be fixed, so all gradient computation is deactivated and the model and
    transformers are put in evaluation mode.

    Attributes:
        model: The torch base model.
        input_variables: List defining the input variables and their order.
        output_variables: List defining the output variables and their order.
        input_transformers: List of transformer objects to apply to input before passing to model.
        output_transformers: List of transformer objects to apply to output of model.
        output_format: Determines format of outputs: "tensor" or "raw".
        device: Device on which the model will be evaluated. Defaults to "cpu".
        fixed_model: If true, the model and transformers are put in evaluation mode and all gradient
          computation is deactivated.
        precision: Precision of the model, either "double" or "single".
    """
    model: torch.nn.Module
    input_transformers: list[Union[ReversibleInputTransform, torch.nn.Linear]] = None
    output_transformers: list[Union[ReversibleInputTransform, torch.nn.Linear]] = None
    output_format: str = "tensor"
    device: Union[torch.device, str] = "cpu"
    fixed_model: bool = True
    precision: str = "double"

    def __init__(self, *args, **kwargs):
        """Initializes TorchModel.

        After pydantic initialization, missing transformer lists are replaced by empty lists. The model and
        module transformers are then cast to the configured precision, frozen if ``fixed_model`` is true, and
        moved to ``device``.

        Args:
            *args: Accepts a single argument which is the model configuration as dictionary, YAML or JSON
              formatted string or file path.
            **kwargs: See class attributes.

        Raises:
            ValueError: If ``precision`` is neither "double" nor "single".
        """
        super().__init__(*args, **kwargs)
        self.input_transformers = [] if self.input_transformers is None else self.input_transformers
        self.output_transformers = [] if self.output_transformers is None else self.output_transformers

        # reading ``dtype`` casts the model and transformers to the configured precision (the property calls
        # ``_set_precision``) and raises for an unknown precision; the returned value is not needed here
        _ = self.dtype

        # fixed model: set full model in eval mode and deactivate all gradients
        if self.fixed_model:
            self.model.eval().requires_grad_(False)
            for t in self.input_transformers + self.output_transformers:
                if isinstance(t, torch.nn.Module):
                    t.eval().requires_grad_(False)

        # ensure consistent device
        self.to(self.device)

    @property
    def dtype(self):
        """Torch dtype matching ``precision``.

        Reading this property also casts the model and its module transformers to that dtype (via
        ``_set_precision``) and caches the value in ``self._dtype``.

        Raises:
            ValueError: If ``precision`` is neither "double" nor "single".
        """
        precision_to_dtype = {"double": torch.double, "single": torch.float}
        if self.precision not in precision_to_dtype:
            raise ValueError(
                f"Unknown precision {self.precision}, "
                f"expected one of ['double', 'single']."
            )
        self._dtype = precision_to_dtype[self.precision]
        self._set_precision(self._dtype)
        return self._dtype

    @property
    def _tkwargs(self):
        """Tensor factory keyword arguments (``device`` and ``dtype``) matching the model.

        Note: evaluating ``self.dtype`` re-applies the precision to the model and transformers.
        """
        return dict(device=self.device, dtype=self.dtype)

    @field_validator("model", mode="before")
    def validate_torch_model(cls, v):
        """Loads the torch model from file when given a path; any other value is passed through unchanged.

        Raises:
            OSError: If ``v`` is a path that does not exist.
        """
        if isinstance(v, (str, os.PathLike)):
            if not os.path.exists(v):
                raise OSError(f"File {v} is not found.")
            # NOTE(security): the file holds a pickled ``torch.nn.Module``; unpickling can execute arbitrary
            # code, so only load files from trusted sources. ``weights_only=False`` is required to load a full
            # module since torch 2.6 switched the ``torch.load`` default to ``weights_only=True``.
            v = torch.load(v, weights_only=False)
        return v

    @field_validator("input_transformers", "output_transformers", mode="before")
    def validate_transformers(cls, v):
        """Loads transformers given as file paths; transformer objects are kept as they are.

        Raises:
            ValueError: If ``v`` is not a list.
            OSError: If a given path does not exist.
        """
        if not isinstance(v, list):
            raise ValueError("Transformers must be passed as list.")
        loaded_transformers = []
        for t in v:
            if isinstance(t, (str, os.PathLike)):
                if not os.path.exists(t):
                    raise OSError(f"File {t} is not found.")
                # NOTE(security): transformer files are pickled torch modules; unpickling can execute arbitrary
                # code, so only load trusted files. ``weights_only=False`` is required since torch 2.6 changed
                # the ``torch.load`` default to ``weights_only=True``.
                t = torch.load(t, weights_only=False)
            loaded_transformers.append(t)
        return loaded_transformers

    @field_validator("output_format")
    def validate_output_format(cls, v):
        """Accepts only the supported output formats ("tensor", "variable" or "raw").

        Raises:
            ValueError: For any other value.
        """
        allowed = ["tensor", "variable", "raw"]
        if v in allowed:
            return v
        raise ValueError(f"Unknown output format {v}, expected one of {allowed}.")

    def _set_precision(self, value: torch.dtype):
        """Casts the wrapped model and every ``torch.nn.Module`` transformer to ``value``."""
        self.model.to(dtype=value)
        all_transformers = [*self.input_transformers, *self.output_transformers]
        for module in filter(lambda t: isinstance(t, torch.nn.Module), all_transformers):
            module.to(dtype=value)

    def _evaluate(
            self,
            input_dict: dict[str, Union[float, torch.Tensor]],
    ) -> dict[str, Union[float, torch.Tensor]]:
        """Runs the full torch pipeline on already validated inputs.

        Steps: format as tensors -> arrange in input order -> input transformers -> model ->
        output (un-)transformers -> split per output name -> final output preparation.

        Args:
            input_dict: Input dictionary on which to evaluate the model.

        Returns:
            Dictionary of output variable names to values.
        """
        tensor = self._arrange_inputs(self._format_inputs(input_dict))
        tensor = self._transform_inputs(tensor)
        tensor = self._transform_outputs(self.model(tensor))
        return self._prepare_outputs(self._parse_outputs(tensor))

    def input_validation(self, input_dict: dict[str, Union[float, torch.Tensor]]):
        """Validates input dictionary before evaluation.

        Every (itemized) value, including defaults filled in for missing inputs, is validated against its
        variable definition. The returned dictionary only contains the provided inputs.

        Args:
            input_dict: Input dictionary to validate.

        Returns:
            Validated input dictionary; tensor values are moved to the model's device and dtype.
        """
        # validate input type (ints only are cast to floats for scalars)
        validated_input = InputDictModel(input_dict=input_dict).input_dict
        # format inputs as tensors w/o changing the dtype
        formatted_inputs = self._format_inputs(validated_input)
        # check default values for missing inputs
        filled_inputs = self._fill_default_inputs(formatted_inputs)
        # itemize inputs for validation
        itemized_inputs = self._itemize_dict(filled_inputs)

        for ele in itemized_inputs:
            # validate values that were in the torch tensor
            # any ints in the torch tensor will be cast to floats by Pydantic
            # but others will be caught, e.g. booleans
            ele = InputDictModel(input_dict=ele).input_dict
            # validate each value based on its var class and config
            super().input_validation(ele)

        # return the validated input dict for consistency w/ casting ints to floats; only tensor values are
        # moved, so a dict mixing plain floats and tensors does not fail on ``float.to``
        if any(isinstance(value, torch.Tensor) for value in validated_input.values()):
            tkwargs = self._tkwargs
            validated_input = {
                k: v.to(**tkwargs) if isinstance(v, torch.Tensor) else v
                for k, v in validated_input.items()
            }

        return validated_input

    def output_validation(self, output_dict: dict[str, Union[float, torch.Tensor]]):
        """Splits tensor outputs into single items and validates each item with the base-class validation.

        Args:
            output_dict: Dictionary of output variable names to values.
        """
        for single_output in self._itemize_dict(output_dict):
            super().output_validation(single_output)

    def random_input(self, n_samples: int = 1) -> dict[str, torch.Tensor]:
        """Generates random input(s) for the model.

        Scalar variables are sampled uniformly within their ``value_range``; any other variable gets its
        default value repeated ``n_samples`` times.

        Args:
            n_samples: Number of random samples to generate.

        Returns:
            Dictionary of input variable names to tensors.
        """
        input_dict = {}
        for var in self.input_variables:
            if isinstance(var, ScalarVariable):
                low, high = var.value_range[0], var.value_range[1]
                input_dict[var.name] = low + torch.rand(size=(n_samples,)) * (high - low)
            else:
                # previously this tensor was computed but never stored, silently dropping the input
                input_dict[var.name] = torch.tensor(var.default_value, **self._tkwargs).repeat((n_samples, 1))
        return input_dict

    def random_evaluate(self, n_samples: int = 1) -> dict[str, Union[float, torch.Tensor]]:
        """Evaluates the model on inputs drawn by ``random_input``.

        Args:
            n_samples: Number of random samples to evaluate.

        Returns:
            Dictionary of variable names to outputs.
        """
        return self.evaluate(self.random_input(n_samples))

    def to(self, device: Union[torch.device, str]):
        """Moves the model and every ``torch.nn.Module`` transformer to ``device`` and stores it in ``self.device``.

        Args:
            device: Device on which the model will be evaluated.
        """
        self.model.to(device)
        for transformer in (*self.input_transformers, *self.output_transformers):
            if not isinstance(transformer, torch.nn.Module):
                continue
            transformer.to(device)
        self.device = device

    def insert_input_transformer(self, new_transformer: ReversibleInputTransform, loc: int):
        """Inserts ``new_transformer`` into the input transformer list at index ``loc`` (slice semantics).

        Args:
            new_transformer: New transformer to add.
            loc: Location where the new transformer shall be added to the transformer list.
        """
        current = self.input_transformers
        # assign a fresh list; with validate_assignment=True this also re-runs field validation
        self.input_transformers = [*current[:loc], new_transformer, *current[loc:]]

    def insert_output_transformer(self, new_transformer: ReversibleInputTransform, loc: int):
        """Inserts ``new_transformer`` into the output transformer list at index ``loc`` (slice semantics).

        Args:
            new_transformer: New transformer to add.
            loc: Location where the new transformer shall be added to the transformer list.
        """
        current = self.output_transformers
        # assign a fresh list; with validate_assignment=True this also re-runs field validation
        self.output_transformers = [*current[:loc], new_transformer, *current[loc:]]

    def update_input_variables_to_transformer(self, transformer_loc: int) -> list[ScalarVariable]:
        """Returns input variables updated to the transformer at the given location.

        Updated are the value ranges and default of the input variables. This allows, e.g., to add a
        calibration transformer and to update the input variable specification accordingly.

        Each bound/default vector is passed forward through the transformers preceding ``transformer_loc``,
        inverted through the transformer at ``transformer_loc``, and then inverted back through the preceding
        transformers in reverse order.

        Args:
            transformer_loc: The location of the input transformer to adjust for.

        Returns:
            The updated input variables (deep copies; the model's own variables are not modified).
        """
        def forward(transformer, x):
            # apply one input transformer in its forward direction
            if isinstance(transformer, ReversibleInputTransform):
                return transformer.transform(x)
            return transformer(x)

        def inverse(transformer, x):
            # invert one input transformer; linear layers y = x @ W.T + b are inverted analytically
            if isinstance(transformer, ReversibleInputTransform):
                return transformer.untransform(x)
            w, b = transformer.weight, transformer.bias
            return torch.matmul((x - b), torch.linalg.inv(w.T))

        # create the bound tensors on the model's device so they can pass through the transformers
        tkwargs = self._tkwargs
        x_old = {
            "min": torch.tensor([var.value_range[0] for var in self.input_variables], **tkwargs),
            "max": torch.tensor([var.value_range[1] for var in self.input_variables], **tkwargs),
            "default": torch.tensor([var.default_value for var in self.input_variables], **tkwargs),
        }
        preceding = self.input_transformers[:transformer_loc]
        target = self.input_transformers[transformer_loc]
        x_new = {}
        for key, x in x_old.items():
            # compute previous limits at transformer location
            for transformer in preceding:
                x = forward(transformer, x)
            # untransform of transformer to adjust for
            x = inverse(target, x)
            # backtrack through transformers; each one is inverted according to its own type
            for transformer in reversed(preceding):
                x = inverse(transformer, x)
            x_new[key] = x
        updated_variables = deepcopy(self.input_variables)
        for i, var in enumerate(updated_variables):
            var.value_range = [x_new["min"][i].item(), x_new["max"][i].item()]
            var.default_value = x_new["default"][i].item()
        return updated_variables

    def _format_inputs(
            self,
            input_dict: dict[str, Union[float, torch.Tensor]],
    ) -> dict[str, torch.Tensor]:
        """Converts every input value to a tensor with all size-1 dimensions squeezed out.

        Tensor values are used as given (no dtype change); other values are converted with ``torch.tensor``.

        Args:
            input_dict: Dictionary of input variable names to values.

        Returns:
            Dictionary of input variable names to tensors.
        """
        def as_tensor(value):
            return value if isinstance(value, torch.Tensor) else torch.tensor(value)

        return {name: as_tensor(value).squeeze() for name, value in input_dict.items()}

    def _fill_default_inputs(self, input_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Adds a default-value tensor for every input variable missing from ``input_dict``.

        The dictionary is updated in place and also returned.

        Args:
            input_dict: Dictionary of input variable names to tensors.

        Returns:
            The same dictionary, now containing every input variable.
        """
        missing = [var for var in self.input_variables if var.name not in input_dict]
        for var in missing:
            input_dict[var.name] = torch.tensor(var.default_value, **self._tkwargs)
        return input_dict

    def _arrange_inputs(self, formatted_inputs: dict[str, torch.Tensor]) -> torch.Tensor:
        """Enforces order of input variables.

        Enforces the order of the input variables to be passed to the transformers and model and updates the
        returned tensor with default values for any inputs that are missing. With no inputs at all, a single
        sample made of the default values is returned.

        Args:
            formatted_inputs: Dictionary of input variable names to tensors.

        Returns:
            Ordered input tensor to be passed to the transformers.

        Raises:
            ValueError: If the inputs have inconsistent shapes, a key is not an input variable name, or the
              last dimension does not match the number of inputs.
        """
        default_tensor = torch.tensor(
            [var.default_value for var in self.input_variables], **self._tkwargs
        )

        # determine the common (batch) shape; an empty dict previously raised IndexError here
        input_shapes = [value.shape for value in formatted_inputs.values()]
        batch_shape = input_shapes[0] if input_shapes else torch.Size([])
        if any(shape != batch_shape for shape in input_shapes):
            raise ValueError("Inputs have inconsistent shapes.")

        input_tensor = torch.tile(default_tensor, dims=(*batch_shape, 1))
        index_by_name = {name: i for i, name in enumerate(self.input_names)}
        for key, value in formatted_inputs.items():
            if key not in index_by_name:
                raise ValueError(f"Unknown input variable {key}, expected one of {self.input_names}.")
            input_tensor[..., index_by_name[key]] = value

        if input_tensor.shape[-1] != len(self.input_names):
            raise ValueError(
                "Last dimension of input tensor doesn't match the expected number of inputs: "
                f"received {tuple(input_tensor.shape)}, expected {len(self.input_names)} as the last dimension."
            )
        return input_tensor

    def _transform_inputs(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Applies transformations to the inputs.

        Args:
            input_tensor: Ordered input tensor to be passed to the transformers.

        Returns:
            Tensor of transformed inputs to be passed to the model.
        """
        for transformer in self.input_transformers:
            if isinstance(transformer, ReversibleInputTransform):
                input_tensor = transformer.transform(input_tensor)
            else:
                input_tensor = transformer(input_tensor)
        return input_tensor

    def _transform_outputs(self, output_tensor: torch.Tensor) -> torch.Tensor:
        """(Un-)Transforms the model output tensor.

        Args:
            output_tensor: Output tensor from the model.

        Returns:
            (Un-)Transformed output tensor.
        """
        for transformer in self.output_transformers:
            if isinstance(transformer, ReversibleInputTransform):
                output_tensor = transformer.untransform(output_tensor)
            else:
                w, b = transformer.weight, transformer.bias
                output_tensor = torch.matmul((output_tensor - b), torch.linalg.inv(w.T))
        return output_tensor

    def _parse_outputs(self, output_tensor: torch.Tensor) -> dict[str, torch.Tensor]:
        """Constructs dictionary from model output tensor.

        Args:
            output_tensor: (Un-)transformed output tensor from the model.

        Returns:
            Dictionary of output variable names to (un-)transformed tensors.
        """
        parsed_outputs = {}
        if output_tensor.dim() in [0, 1]:
            output_tensor = output_tensor.unsqueeze(0)
        if len(self.output_names) == 1:
            parsed_outputs[self.output_names[0]] = output_tensor.squeeze()
        else:
            for idx, output_name in enumerate(self.output_names):
                parsed_outputs[output_name] = output_tensor[..., idx].squeeze()
        return parsed_outputs

    def _prepare_outputs(
            self,
            parsed_outputs: dict[str, torch.Tensor],
    ) -> dict[str, Union[float, torch.Tensor]]:
        """Updates and returns outputs according to output_format.

        Updates the output variables within the model to reflect the new values.

        Args:
            parsed_outputs: Dictionary of output variable names to transformed tensors.

        Returns:
            Dictionary of output variable names to values depending on output_format.
        """
        if self.output_format.lower() == "tensor":
            return parsed_outputs
        else:
            return {key: value.item() if value.squeeze().dim() == 0 else value
                    for key, value in parsed_outputs.items()}

    @staticmethod
    def _itemize_dict(d: dict[str, Union[float, torch.Tensor]]) -> list[dict[str, Union[float, torch.Tensor]]]:
        """Itemizes the given in-/output dictionary.

        Args:
            d: Dictionary to itemize.

        Returns:
            List of in-/output dictionaries, each containing only a single value per in-/output.
        """
        has_tensors = any([isinstance(value, torch.Tensor) for value in d.values()])
        itemized_dicts = []
        if has_tensors:
            for k, v in d.items():
                for i, ele in enumerate(v.flatten()):
                    if i >= len(itemized_dicts):
                        itemized_dicts.append({k: ele.item()})
                    else:
                        itemized_dicts[i][k] = ele.item()
        else:
            itemized_dicts = [d]
        return itemized_dicts

__init__(*args, **kwargs)

Initializes TorchModel.

Parameters:

Name Type Description Default
*args

Accepts a single argument which is the model configuration as dictionary, YAML or JSON formatted string or file path.

()
**kwargs

See class attributes.

{}
Source code in lume_model/models/torch_model.py (lines 42-65)
def __init__(self, *args, **kwargs):
    """Initializes TorchModel.

    After the base initialization, unset transformer lists are normalized to empty lists. The dtype
    property is then accessed, which per the original design sets the precision across model and
    transformers. If the model is fixed, it is frozen. Finally, everything is moved to ``self.device``.

    Args:
        *args: Accepts a single argument which is the model configuration as dictionary, YAML or JSON
          formatted string or file path.
        **kwargs: See class attributes.
    """
    super().__init__(*args, **kwargs)
    in_transformers = self.input_transformers
    self.input_transformers = in_transformers if in_transformers is not None else []
    out_transformers = self.output_transformers
    self.output_transformers = out_transformers if out_transformers is not None else []

    # accessing the dtype property sets precision across model and transformers
    _ = self.dtype

    if self.fixed_model:
        # fixed model: evaluation mode and no gradients for the model and every torch module transformer
        self.model.eval().requires_grad_(False)
        all_transformers = [*self.input_transformers, *self.output_transformers]
        for module in (t for t in all_transformers if isinstance(t, torch.nn.Module)):
            module.eval().requires_grad_(False)

    # ensure consistent device
    self.to(self.device)

input_validation(input_dict)

Validates input dictionary before evaluation.

Parameters:

Name Type Description Default
input_dict dict[str, Union[float, Tensor]]

Input dictionary to validate.

required

Returns:

Type Description

Validated input dictionary.

Source code in lume_model/models/torch_model.py (lines 144-174)
def input_validation(self, input_dict: dict[str, Union[float, torch.Tensor]]):
    """Validates input dictionary before evaluation.

    The dictionary is first type-validated, which casts ints to floats for scalars. It is then formatted,
    filled with defaults for missing inputs, and itemized. Each single-value dictionary is validated
    against the variable definitions.

    Args:
        input_dict: Input dictionary to validate. Values may be floats, tensors, or a mix of both.

    Returns:
        Validated input dictionary. Tensor values are moved to the model's tensor kwargs
        (``self._tkwargs``); non-tensor values are returned as validated.
    """
    # validate input type (ints only are cast to floats for scalars)
    validated_input = InputDictModel(input_dict=input_dict).input_dict
    # format inputs as tensors w/o changing the dtype
    formatted_inputs = self._format_inputs(validated_input)
    # check default values for missing inputs
    filled_inputs = self._fill_default_inputs(formatted_inputs)
    # itemize inputs for validation
    itemized_inputs = self._itemize_dict(filled_inputs)

    for ele in itemized_inputs:
        # validate values that were in the torch tensor
        # any ints in the torch tensor will be cast to floats by Pydantic
        # but others will be caught, e.g. booleans
        ele = InputDictModel(input_dict=ele).input_dict
        # validate each value based on its var class and config
        super().input_validation(ele)

    # return the validated input dict for consistency w/ casting ints to floats;
    # only tensors are converted so that floats mixed with tensors do not fail on `.to`
    if any(isinstance(value, torch.Tensor) for value in validated_input.values()):
        validated_input = {
            k: v.to(**self._tkwargs) if isinstance(v, torch.Tensor) else v
            for k, v in validated_input.items()
        }

    return validated_input

insert_input_transformer(new_transformer, loc)

Inserts an additional input transformer at the given location.

Parameters:

Name Type Description Default
new_transformer ReversibleInputTransform

New transformer to add.

required
loc int

Location where the new transformer shall be added to the transformer list.

required
Source code in lume_model/models/torch_model.py (lines 224-232)
def insert_input_transformer(self, new_transformer: ReversibleInputTransform, loc: int):
    """Inserts an additional input transformer at the given location.

    A new list is built and assigned to ``input_transformers``; the previous list object is not mutated.

    Args:
        new_transformer: New transformer to add.
        loc: Location where the new transformer shall be added to the transformer list.
    """
    current = self.input_transformers
    self.input_transformers = [*current[:loc], new_transformer, *current[loc:]]

insert_output_transformer(new_transformer, loc)

Inserts an additional output transformer at the given location.

Parameters:

Name Type Description Default
new_transformer ReversibleInputTransform

New transformer to add.

required
loc int

Location where the new transformer shall be added to the transformer list.

required
Source code in lume_model/models/torch_model.py (lines 234-242)
def insert_output_transformer(self, new_transformer: ReversibleInputTransform, loc: int):
    """Inserts an additional output transformer at the given location.

    A new list is built and assigned to ``output_transformers``; the previous list object is not mutated.

    Args:
        new_transformer: New transformer to add.
        loc: Location where the new transformer shall be added to the transformer list.
    """
    current = self.output_transformers
    self.output_transformers = [*current[:loc], new_transformer, *current[loc:]]

output_validation(output_dict)

Itemizes tensors before performing output validation.

Source code in lume_model/models/torch_model.py (lines 176-180)
def output_validation(self, output_dict: dict[str, Union[float, torch.Tensor]]):
    """Splits the outputs into single-value dictionaries and validates each with the base implementation."""
    for single_output in self._itemize_dict(output_dict):
        super().output_validation(single_output)

random_evaluate(n_samples=1)

Returns random evaluation(s) of the model.

Parameters:

Name Type Description Default
n_samples int

Number of random samples to evaluate.

1

Returns:

Type Description
dict[str, Union[float, Tensor]]

Dictionary of variable names to outputs.

Source code in lume_model/models/torch_model.py (lines 200-210)
def random_evaluate(self, n_samples: int = 1) -> dict[str, Union[float, torch.Tensor]]:
    """Evaluates the model on freshly generated random input(s).

    Args:
        n_samples: Number of random samples to evaluate.

    Returns:
        Dictionary of variable names to outputs, as returned by ``evaluate``.
    """
    return self.evaluate(self.random_input(n_samples))

random_input(n_samples=1)

Generates random input(s) for the model.

Parameters:

Name Type Description Default
n_samples int

Number of random samples to generate.

1

Returns:

Type Description
dict[str, Tensor]

Dictionary of input variable names to tensors.

Source code in lume_model/models/torch_model.py (lines 182-198)
def random_input(self, n_samples: int = 1) -> dict[str, torch.Tensor]:
    """Generates random input(s) for the model.

    Scalar variables are sampled uniformly within their value range. All other variables use their
    default value, repeated once per sample.

    Args:
        n_samples: Number of random samples to generate.

    Returns:
        Dictionary of input variable names to tensors (created with ``self._tkwargs``).
    """
    input_dict = {}
    for var in self.input_variables:
        if isinstance(var, ScalarVariable):
            low, high = var.value_range[0], var.value_range[1]
            input_dict[var.name] = low + torch.rand(size=(n_samples,), **self._tkwargs) * (high - low)
        else:
            # previously this tensor was computed but never stored, silently dropping the input
            input_dict[var.name] = torch.tensor(var.default_value, **self._tkwargs).repeat((n_samples, 1))
    return input_dict

to(device)

Updates the device for the model, transformers and default values.

Parameters:

Name Type Description Default
device Union[device, str]

Device on which the model will be evaluated.

required
Source code in lume_model/models/torch_model.py (lines 212-222)
def to(self, device: Union[torch.device, str]):
    """Moves the model and every torch-module transformer to ``device`` and records it in ``self.device``.

    Args:
        device: Device on which the model will be evaluated.
    """
    self.model.to(device)
    transformers = [*self.input_transformers, *self.output_transformers]
    for module in filter(lambda t: isinstance(t, torch.nn.Module), transformers):
        module.to(device)
    self.device = device

update_input_variables_to_transformer(transformer_loc)

Returns input variables updated to the transformer at the given location.

Updated are the value ranges and default of the input variables. This allows, e.g., to add a calibration transformer and to update the input variable specification accordingly.

Parameters:

Name Type Description Default
transformer_loc int

The location of the input transformer to adjust for.

required

Returns:

Type Description
list[ScalarVariable]

The updated input variables.

Source code in lume_model/models/torch_model.py (lines 244-289)
def update_input_variables_to_transformer(self, transformer_loc: int) -> list[ScalarVariable]:
    """Returns input variables updated to the transformer at the given location.

    Updated are the value ranges and default of the input variables. This allows, e.g., to add a
    calibration transformer and to update the input variable specification accordingly.

    The current limits and defaults are:
    - passed forward through all transformers preceding ``transformer_loc``,
    - untransformed by the transformer at ``transformer_loc``,
    - untransformed back through the preceding transformers in reverse order.

    Args:
        transformer_loc: The location of the input transformer to adjust for.

    Returns:
        The updated input variables (deep copies; ``self.input_variables`` is left unchanged).
    """
    def _apply(transformer, x):
        # forward pass through a single transformer
        if isinstance(transformer, ReversibleInputTransform):
            return transformer.transform(x)
        return transformer(x)

    def _invert(transformer, x):
        # inverse of a single transformer; non-reversible transformers are inverted as the affine
        # map y = x @ W^T + b, i.e. they must expose `weight` and `bias` (e.g. torch.nn.Linear)
        if isinstance(transformer, ReversibleInputTransform):
            return transformer.untransform(x)
        w, b = transformer.weight, transformer.bias
        return torch.matmul((x - b), torch.linalg.inv(w.T))

    # create the tensors on the model device so they match transformers moved via `to`
    tensor_kwargs = {"dtype": self.dtype, "device": self.device}
    x_old = {
        "min": torch.tensor([var.value_range[0] for var in self.input_variables], **tensor_kwargs),
        "max": torch.tensor([var.value_range[1] for var in self.input_variables], **tensor_kwargs),
        "default": torch.tensor([var.default_value for var in self.input_variables], **tensor_kwargs),
    }
    preceding = self.input_transformers[:transformer_loc]
    target = self.input_transformers[transformer_loc]
    x_new = {}
    for key, x in x_old.items():
        # compute values at the transformer location
        for transformer in preceding:
            x = _apply(transformer, x)
        # untransform of transformer to adjust for
        x = _invert(target, x)
        # backtrack through the preceding transformers (each checked individually)
        for transformer in reversed(preceding):
            x = _invert(transformer, x)
        x_new[key] = x
    updated_variables = deepcopy(self.input_variables)
    for i, var in enumerate(updated_variables):
        var.value_range = [x_new["min"][i].item(), x_new["max"][i].item()]
        var.default_value = x_new["default"][i].item()
    return updated_variables

TorchModule

Bases: Module

Wrapper to allow a LUME TorchModel to be used like a torch.nn.Module.

As the base model within the TorchModel is assumed to be fixed during instantiation, so is the TorchModule.

Source code in lume_model/models/torch_module.py (lines 13-179)
class TorchModule(torch.nn.Module):
    """Wrapper to allow a LUME TorchModel to be used like a torch.nn.Module.

    As the base model within the TorchModel is assumed to be fixed during instantiation,
    so is the TorchModule.

    The base model and the input/output transformers of the wrapped TorchModel are registered
    as submodules of this module.
    """
    def __init__(
        self,
        *args,
        model: TorchModel = None,
        input_order: list[str] = None,
        output_order: list[str] = None,
    ):
        """Initializes TorchModule.

        Args:
            *args: Accepts a single argument which is the model configuration as dictionary, YAML or JSON
              formatted string or file path.

        Keyword Args:
            model: The TorchModel instance to wrap around. Has to be defined if no configuration is passed
              as positional argument.
            input_order: Input names in the order they are passed to the model. If None, the input order of the
              TorchModel is used.
            output_order: Output names in the order they are returned by the model. If None, the output order of
              the TorchModel is used.

        Raises:
            ValueError: If neither a configuration nor a model is given, if a configuration is combined with
              keyword arguments, or if more than one positional argument is given.
        """
        if all(arg is None for arg in [*args, model]):
            raise ValueError("Either a YAML string has to be given or model has to be defined.")
        super().__init__()
        if len(args) == 1:
            # configuration path: keyword arguments may not be combined with a configuration
            if not all(v is None for v in [model, input_order, output_order]):
                raise ValueError("Cannot specify YAML string and keyword arguments for TorchModule init.")
            # TorchModel fields are handed to parse_config under a "model." prefix
            # (presumably matching the nested "model" entry of the configuration - verify in parse_config)
            model_fields = {f"model.{k}": v for k, v in TorchModel.model_fields.items()}
            kwargs = parse_config(args[0], model_fields)
            kwargs["model"] = TorchModel(kwargs["model"])
            # re-run __init__ with keyword arguments only, which takes the registration branch below
            self.__init__(**kwargs)
        elif len(args) > 1:
            raise ValueError(
                "Arguments to TorchModule must be either a single YAML string or keyword arguments."
            )
        else:
            self._model = model
            self._input_order = input_order
            self._output_order = output_order
            # register base model and transformers as submodules so torch.nn.Module utilities see them
            self.register_module("base_model", self._model.model)
            for i, input_transformer in enumerate(self._model.input_transformers):
                self.register_module(f"input_transformers_{i}", input_transformer)
            for i, output_transformer in enumerate(self._model.output_transformers):
                self.register_module(f"output_transformers_{i}", output_transformer)
            if not model.model.training:  # TorchModel defines train/eval mode
                self.eval()

    @property
    def model(self):
        """The wrapped TorchModel instance."""
        return self._model

    @property
    def input_order(self):
        """Input names in the order of the last dimension of the forward input.

        Falls back to the input names of the wrapped TorchModel if no order was given.
        """
        if self._input_order is None:
            return self._model.input_names
        else:
            return self._input_order

    @property
    def output_order(self):
        """Output names in the order of the last dimension of the forward output.

        Falls back to the output names of the wrapped TorchModel if no order was given.
        """
        if self._output_order is None:
            return self._model.output_names
        else:
            return self._output_order

    def forward(self, x: torch.Tensor):
        """Evaluates the wrapped TorchModel on a tensor input.

        Args:
            x: Input tensor with at least two dimensions whose last dimension holds the inputs in
              ``input_order``.

        Returns:
            Tensor with the outputs stacked in ``output_order`` along the last dimension, with all
            singleton dimensions squeezed.

        Raises:
            ValueError: If ``x`` has fewer than two dimensions.
        """
        # expected input shape: [n_samples, n_dim] or [n_batch, n_samples, n_dim]
        x = self._validate_input(x)
        model_input = self._tensor_to_dictionary(x)
        y_model = self.evaluate_model(model_input)
        y_model = self.manipulate_output(y_model)
        # squeeze for use as prior mean in botorch GPs
        y = self._dictionary_to_tensor(y_model).squeeze()
        return y

    def yaml(
            self,
            base_key: str = "",
            file_prefix: str = "",
            save_models: bool = False,
    ) -> str:
        """Serializes the object and returns a YAML formatted string defining the TorchModule instance.

        Args:
            base_key: Base key for serialization.
            file_prefix: Prefix for generated filenames.
            save_models: Determines whether models are saved to file.

        Returns:
            YAML formatted string defining the TorchModule instance.
        """
        # collect the constructor keyword arguments other than the wrapped model
        # (input_order / output_order, read back through the properties)
        d = {}
        for k, v in inspect.signature(TorchModule.__init__).parameters.items():
            if k not in ["self", "args", "model"]:
                d[k] = getattr(self, k)
        output = json.loads(
            json.dumps(recursive_serialize(d, base_key, file_prefix, save_models))
        )
        # the wrapped model is serialized by its own JSON serializer and nested under "model"
        model_output = json.loads(
            self._model.to_json(
                base_key=base_key,
                file_prefix=file_prefix,
                save_models=save_models,
            )
        )
        output["model"] = model_output
        # create YAML formatted string
        s = yaml.dump({"model_class": self.__class__.__name__} | output,
                      default_flow_style=None, sort_keys=False)
        return s

    def dump(
            self,
            file: Union[str, os.PathLike],
            save_models: bool = True,
            base_key: str = "",
    ):
        """Saves a YAML formatted string defining the TorchModule instance to file.

        Nothing is returned. The file path without its extension is passed to ``yaml`` as the prefix for
        any generated files.

        Args:
            file: File path to which the YAML formatted string and corresponding files are saved.
            base_key: Base key for serialization.
            save_models: Determines whether models are saved to file.
        """
        file_prefix = os.path.splitext(file)[0]
        with open(file, "w") as f:
            f.write(
                self.yaml(
                    save_models=save_models,
                    base_key=base_key,
                    file_prefix=file_prefix,
                )
            )

    def evaluate_model(self, x: dict[str, torch.Tensor]):
        """Placeholder method to modify model calls.

        By default, evaluates the wrapped TorchModel on the input dictionary.
        """
        return self._model.evaluate(x)

    def manipulate_output(self, y_model: dict[str, torch.Tensor]):
        """Placeholder method to modify the model output.

        By default, returns the output dictionary unchanged.
        """
        return y_model

    def _tensor_to_dictionary(self, x: torch.Tensor):
        """Splits the last dimension of x into a dictionary keyed by ``input_order``.

        Each value keeps a trailing singleton dimension.
        """
        input_dict = {}
        for idx, input_name in enumerate(self.input_order):
            input_dict[input_name] = x[..., idx].unsqueeze(-1)
        return input_dict

    def _dictionary_to_tensor(self, y_model: dict[str, torch.Tensor]):
        """Stacks the outputs in ``output_order`` along a new last dimension."""
        output_tensor = torch.stack(
            [y_model[output_name].unsqueeze(-1) for output_name in self.output_order], dim=-1
        )
        return output_tensor

    @staticmethod
    def _validate_input(x: torch.Tensor) -> torch.Tensor:
        """Returns x unchanged; raises ValueError if it has fewer than two dimensions."""
        if x.dim() <= 1:
            raise ValueError(
                f"Expected input dim to be at least 2 ([n_samples, n_features]), received: {tuple(x.shape)}"
            )
        else:
            return x

__init__(*args, model=None, input_order=None, output_order=None)

Initializes TorchModule.

Parameters:

Name Type Description Default
*args

Accepts a single argument which is the model configuration as dictionary, YAML or JSON formatted string or file path.

()

Other Parameters:

Name Type Description
model TorchModel

The TorchModel instance to wrap around. Has to be defined if no configuration is passed as positional argument.

input_order list[str]

Input names in the order they are passed to the model. If None, the input order of the TorchModel is used.

output_order list[str]

Output names in the order they are returned by the model. If None, the output order of the TorchModel is used.

Source code in lume_model/models/torch_module.py (lines 19-63)
def __init__(
    self,
    *args,
    model: TorchModel = None,
    input_order: list[str] = None,
    output_order: list[str] = None,
):
    """Initializes TorchModule.

    Args:
        *args: Accepts a single argument which is the model configuration as dictionary, YAML or JSON
          formatted string or file path.

    Keyword Args:
        model: The TorchModel instance to wrap around. Has to be defined if no configuration is passed
          as positional argument.
        input_order: Input names in the order they are passed to the model. If None, the input order of the
          TorchModel is used.
        output_order: Output names in the order they are returned by the model. If None, the output order of
          the TorchModel is used.

    Raises:
        ValueError: If neither a configuration nor a model is given, if a configuration is combined with
          keyword arguments, or if more than one positional argument is given.
    """
    if all(arg is None for arg in [*args, model]):
        raise ValueError("Either a YAML string has to be given or model has to be defined.")
    super().__init__()
    if len(args) == 1:
        # configuration path: keyword arguments may not be combined with a configuration
        if not all(v is None for v in [model, input_order, output_order]):
            raise ValueError("Cannot specify YAML string and keyword arguments for TorchModule init.")
        # TorchModel fields are handed to parse_config under a "model." prefix
        # (presumably matching the nested "model" entry of the configuration - verify in parse_config)
        model_fields = {f"model.{k}": v for k, v in TorchModel.model_fields.items()}
        kwargs = parse_config(args[0], model_fields)
        kwargs["model"] = TorchModel(kwargs["model"])
        # re-run __init__ with keyword arguments only, which takes the registration branch below
        self.__init__(**kwargs)
    elif len(args) > 1:
        raise ValueError(
            "Arguments to TorchModule must be either a single YAML string or keyword arguments."
        )
    else:
        self._model = model
        self._input_order = input_order
        self._output_order = output_order
        # register base model and transformers as submodules so torch.nn.Module utilities see them
        self.register_module("base_model", self._model.model)
        for i, input_transformer in enumerate(self._model.input_transformers):
            self.register_module(f"input_transformers_{i}", input_transformer)
        for i, output_transformer in enumerate(self._model.output_transformers):
            self.register_module(f"output_transformers_{i}", output_transformer)
        if not model.model.training:  # TorchModel defines train/eval mode
            self.eval()

dump(file, save_models=True, base_key='')

Returns and optionally saves YAML formatted string defining the model.

Parameters:

Name Type Description Default
file Union[str, PathLike]

File path to which the YAML formatted string and corresponding files are saved.

required
base_key str

Base key for serialization.

''
save_models bool

Determines whether models are saved to file.

True
Source code in lume_model/models/torch_module.py (lines 129-150)
def dump(
        self,
        file: Union[str, os.PathLike],
        save_models: bool = True,
        base_key: str = "",
):
    """Saves a YAML formatted string defining the TorchModule instance to file.

    Nothing is returned. The file is opened for writing before ``yaml`` is called. The path without
    its extension is passed to ``yaml`` as the prefix for any generated files.

    Args:
        file: File path to which the YAML formatted string and corresponding files are saved.
        save_models: Determines whether models are saved to file.
        base_key: Base key for serialization.
    """
    prefix, _extension = os.path.splitext(file)
    with open(file, "w") as handle:
        content = self.yaml(
            save_models=save_models,
            base_key=base_key,
            file_prefix=prefix,
        )
        handle.write(content)

evaluate_model(x)

Placeholder method to modify model calls.

Source code in lume_model/models/torch_module.py (lines 152-154)
def evaluate_model(self, x: dict[str, torch.Tensor]):
    """Placeholder method to modify model calls.

    By default, delegates to ``evaluate`` of the wrapped TorchModel.
    """
    wrapped = self._model
    return wrapped.evaluate(x)

manipulate_output(y_model)

Placeholder method to modify the model output.

Source code in lume_model/models/torch_module.py (lines 156-158)
def manipulate_output(self, y_model: dict[str, torch.Tensor]):
    """Placeholder method to modify the model output.

    By default, the output dictionary is passed through untouched (same object).
    """
    unchanged = y_model
    return unchanged

yaml(base_key='', file_prefix='', save_models=False)

Serializes the object and returns a YAML formatted string defining the TorchModule instance.

Parameters:

Name Type Description Default
base_key str

Base key for serialization.

''
file_prefix str

Prefix for generated filenames.

''
save_models bool

Determines whether models are saved to file.

False

Returns:

Type Description
str

YAML formatted string defining the TorchModule instance.

Source code in lume_model/models/torch_module.py (lines 93-127)
def yaml(
        self,
        base_key: str = "",
        file_prefix: str = "",
        save_models: bool = False,
) -> str:
    """Serializes the object and returns a YAML formatted string defining the TorchModule instance.

    The constructor keyword arguments other than the wrapped model are serialized first. The wrapped
    model's own JSON serialization is nested under ``"model"``, and ``"model_class"`` is placed as the
    first key.

    Args:
        base_key: Base key for serialization.
        file_prefix: Prefix for generated filenames.
        save_models: Determines whether models are saved to file.

    Returns:
        YAML formatted string defining the TorchModule instance.
    """
    init_params = inspect.signature(TorchModule.__init__).parameters
    settings = {
        name: getattr(self, name)
        for name in init_params
        if name not in ("self", "args", "model")
    }
    serialized = json.loads(
        json.dumps(recursive_serialize(settings, base_key, file_prefix, save_models))
    )
    serialized["model"] = json.loads(
        self._model.to_json(
            base_key=base_key,
            file_prefix=file_prefix,
            save_models=save_models,
        )
    )
    document = {"model_class": self.__class__.__name__, **serialized}
    return yaml.dump(document, default_flow_style=None, sort_keys=False)