Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions docs/source/_rst/_code.rst
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,10 @@ Optimizers and Schedulers
.. toctree::
:titlesonly:

Optimizer <optim/optimizer_interface.rst>
Scheduler <optim/scheduler_interface.rst>
TorchOptimizer <optim/torch_optimizer.rst>
TorchScheduler <optim/torch_scheduler.rst>
Optimizer Interface <optim/optimizer_interface.rst>
Scheduler Interface <optim/scheduler_interface.rst>
Torch Optimizer <optim/torch_optimizer.rst>
Torch Scheduler <optim/torch_scheduler.rst>


Adaptive Functions
Expand Down
6 changes: 3 additions & 3 deletions docs/source/_rst/optim/optimizer_interface.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Optimizer
============
Optimizer Interface
=====================
.. currentmodule:: pina.optim.optimizer_interface

.. autoclass:: pina._src.optim.optimizer_interface.Optimizer
.. autoclass:: pina._src.optim.optimizer_interface.OptimizerInterface
:members:
:show-inheritance:
6 changes: 3 additions & 3 deletions docs/source/_rst/optim/scheduler_interface.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Scheduler
=============
Scheduler Interface
=====================
.. currentmodule:: pina.optim.scheduler_interface

.. autoclass:: pina._src.optim.scheduler_interface.Scheduler
.. autoclass:: pina._src.optim.scheduler_interface.SchedulerInterface
:members:
:show-inheritance:
2 changes: 1 addition & 1 deletion docs/source/_rst/optim/torch_optimizer.rst
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
TorchOptimizer
Torch Optimizer
===============
.. currentmodule:: pina.optim.torch_optimizer

Expand Down
2 changes: 1 addition & 1 deletion docs/source/_rst/optim/torch_scheduler.rst
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
TorchScheduler
Torch Scheduler
===============
.. currentmodule:: pina.optim.torch_scheduler

Expand Down
25 changes: 16 additions & 9 deletions pina/_src/optim/optimizer_interface.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,30 @@
"""Module for the PINA Optimizer."""
"""Module for the Optimizer Interface."""

from abc import ABCMeta, abstractmethod


class Optimizer(metaclass=ABCMeta):
class OptimizerInterface(metaclass=ABCMeta):
"""
Abstract base class for defining an optimizer. All specific optimizers
should inherit form this class and implement the required methods.
Abstract interface for all optimizers.
"""

@property
@abstractmethod
def instance(self):
def hook(self, parameters):
"""
Abstract property to retrieve the optimizer instance.
Execute custom logic associated with the optimizer instance.

This method is intended to encapsulate any additional behavior that
should be triggered during the optimization process.

:param dict parameters: The parameters of the model to be optimized.
"""

@property
@abstractmethod
def hook(self):
def instance(self):
"""
Abstract method to define the hook logic for the optimizer.
The underlying optimizer object.

:return: The optimizer instance.
:rtype: object
"""
26 changes: 17 additions & 9 deletions pina/_src/optim/scheduler_interface.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,31 @@
"""Module for the PINA Scheduler."""
"""Module for the Scheduler Interface."""

from abc import ABCMeta, abstractmethod


class Scheduler(metaclass=ABCMeta):
class SchedulerInterface(metaclass=ABCMeta):
"""
Abstract base class for defining a scheduler. All specific schedulers should
inherit form this class and implement the required methods.
Abstract interface for all schedulers.
"""

@property
@abstractmethod
def instance(self):
def hook(self, optimizer):
"""
Abstract property to retrieve the scheduler instance.
Execute custom logic associated with the scheduler instance.

This method is intended to encapsulate any additional behavior that
should be triggered during the optimization process.

:param OptimizerInterface optimizer: The optimizer instance associated
with the scheduler.
"""

@property
@abstractmethod
def hook(self):
def instance(self):
"""
Abstract method to define the hook logic for the scheduler.
The underlying scheduler object.

:return: The scheduler instance.
:rtype: object
"""
33 changes: 22 additions & 11 deletions pina/_src/optim/torch_optimizer.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,46 @@
"""Module for the PINA Torch Optimizer"""
"""Module for wrapping PyTorch optimizers."""

import torch

from pina._src.core.utils import check_consistency
from pina._src.optim.optimizer_interface import Optimizer
from pina._src.optim.optimizer_interface import OptimizerInterface


class TorchOptimizer(Optimizer):
class TorchOptimizer(OptimizerInterface):
"""
A wrapper class for using PyTorch optimizers.
The wrapper class for PyTorch optimizers.

This class wraps a ``torch.optim.Optimizer`` class and defers its
instantiation until runtime. It enables a consistent interface across
different optimizer backends while leveraging PyTorch’s optimization
algorithms.
"""

def __init__(self, optimizer_class, **kwargs):
"""
Initialization of the :class:`TorchOptimizer` class.

:param torch.optim.Optimizer optimizer_class: A
:class:`torch.optim.Optimizer` class.
:param dict kwargs: Additional parameters passed to ``optimizer_class``,
see more
:param torch.optim.Optimizer optimizer_class: The subclass of
``torch.optim.Optimizer`` to be instantiated.
:param dict kwargs: Additional keyword arguments forwarded to the
optimizer constructor. See more
`here <https://pytorch.org/docs/stable/optim.html#algorithms>`_.
:raises ValueError: If ``optimizer_class`` is not a subclass of
``torch.optim.Optimizer``.
"""
# Check consistency
check_consistency(optimizer_class, torch.optim.Optimizer, subclass=True)

# Initialize attributes
self.optimizer_class = optimizer_class
self.kwargs = kwargs
self._optimizer_instance = None

def hook(self, parameters):
"""
Initialize the optimizer instance with the given parameters.
Execute custom logic associated with the optimizer instance.

This method is intended to encapsulate any additional behavior that
should be triggered during the optimization process.

:param dict parameters: The parameters of the model to be optimized.
"""
Expand All @@ -40,7 +51,7 @@ def hook(self, parameters):
@property
def instance(self):
"""
Get the optimizer instance.
The underlying optimizer object.

:return: The optimizer instance.
:rtype: torch.optim.Optimizer
Expand Down
51 changes: 29 additions & 22 deletions pina/_src/optim/torch_scheduler.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,35 @@
"""Module for the PINA Torch Optimizer"""

try:
from torch.optim.lr_scheduler import LRScheduler # torch >= 2.0
except ImportError:
from torch.optim.lr_scheduler import (
_LRScheduler as LRScheduler,
) # torch < 2.0
"""Module for wrapping PyTorch schedulers."""

from torch.optim.lr_scheduler import LRScheduler
from pina._src.core.utils import check_consistency
from pina._src.optim.optimizer_interface import Optimizer
from pina._src.optim.scheduler_interface import Scheduler
from pina._src.optim.optimizer_interface import OptimizerInterface
from pina._src.optim.scheduler_interface import SchedulerInterface


class TorchScheduler(Scheduler):
class TorchScheduler(SchedulerInterface):
"""
A wrapper class for using PyTorch schedulers.
The wrapper class for PyTorch schedulers.

This class wraps a ``torch.optim.lr_scheduler.LRScheduler`` class and defers
its instantiation until runtime, once the optimizer instance is available.
"""

def __init__(self, scheduler_class, **kwargs):
"""
Initialization of the :class:`TorchScheduler` class.

:param torch.optim.LRScheduler scheduler_class: A
:class:`torch.optim.LRScheduler` class.
:param dict kwargs: Additional parameters passed to ``scheduler_class``,
see more
`here <https://pytorch.org/docs/stable/optim.html#algorithms>_`.
:param torch.optim.LRScheduler scheduler_class: The subclass of
``torch.optim.lr_scheduler.LRScheduler`` to be instantiated.
:param dict kwargs: Additional keyword arguments forwarded to the
scheduler constructor. See more
`here <https://pytorch.org/docs/stable/optim.html#algorithms>`_.
:raises ValueError: If ``scheduler_class`` is not a subclass of
``torch.optim.lr_scheduler.LRScheduler``.
"""
# Check consistency
check_consistency(scheduler_class, LRScheduler, subclass=True)

# Initialize attributes
self.scheduler_class = scheduler_class
self.kwargs = kwargs
self._scheduler_instance = None
Expand All @@ -37,19 +38,25 @@ def hook(self, optimizer):
"""
Initialize the scheduler instance with the given parameters.

:param dict parameters: The parameters of the optimizer.
:param OptimizerInterface optimizer: The optimizer instance associated
with the scheduler.
:raises ValueError: If ``optimizer`` is not an instance of
:class:`OptimizerInterface`.
"""
check_consistency(optimizer, Optimizer)
# Check consistency
check_consistency(optimizer, OptimizerInterface)

# Initialize the scheduler instance
self._scheduler_instance = self.scheduler_class(
optimizer.instance, **self.kwargs
)

@property
def instance(self):
"""
Get the scheduler instance.
The underlying scheduler object.

:return: The scheduelr instance.
:rtype: torch.optim.LRScheduler
:return: The scheduler instance.
:rtype: torch.optim.lr_scheduler.LRScheduler
"""
return self._scheduler_instance
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,10 @@ def __init__(
:param torch.nn.Module loss: The loss function to be minimized.
If ``None``, the :class:`torch.nn.MSELoss` loss is used.
Default is ``None``.
:param Optimizer optimizer: The optimizer to be used.
:param OptimizerInterface optimizer: The optimizer to be used.
If ``None``, the :class:`torch.optim.Adam` optimizer is used.
Default is ``None``.
:param Scheduler scheduler: Learning rate scheduler.
:param SchedulerInterface scheduler: Learning rate scheduler.
If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
scheduler is used. Default is ``None``.
:param WeightingInterface weighting: The weighting schema to be used.
Expand Down
4 changes: 2 additions & 2 deletions pina/_src/solver/ensemble_solver/ensemble_pinn.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,10 @@ def __init__(
:param torch.nn.Module loss: The loss function to be minimized.
If ``None``, the :class:`torch.nn.MSELoss` loss is used.
Default is ``None``.
:param Optimizer optimizer: The optimizer to be used.
:param OptimizerInterface optimizers: The optimizers to be used.
If ``None``, the :class:`torch.optim.Adam` optimizer is used.
Default is ``None``.
:param Scheduler scheduler: Learning rate scheduler.
:param SchedulerInterface schedulers: Learning rate schedulers.
If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
scheduler is used. Default is ``None``.
:param WeightingInterface weighting: The weighting schema to be used.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,10 @@ def __init__(

:param BaseProblem problem: The problem to be solved.
:param torch.nn.Module models: The neural network models to be used.
:param Optimizer optimizer: The optimizer to be used.
:param OptimizerInterface optimizers: The optimizers to be used.
If ``None``, the :class:`torch.optim.Adam` optimizer is used.
Default is ``None``.
:param Scheduler scheduler: Learning rate scheduler.
:param SchedulerInterface schedulers: Learning rate schedulers.
If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
scheduler is used. Default is ``None``.
:param WeightingInterface weighting: The weighting schema to be used.
Expand Down
4 changes: 2 additions & 2 deletions pina/_src/solver/ensemble_solver/ensemble_supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,10 @@ def __init__(
:param torch.nn.Module loss: The loss function to be minimized.
If ``None``, the :class:`torch.nn.MSELoss` loss is used.
Default is ``None``.
:param Optimizer optimizer: The optimizer to be used.
:param OptimizerInterface optimizers: The optimizers to be used.
If ``None``, the :class:`torch.optim.Adam` optimizer is used.
Default is ``None``.
:param Scheduler scheduler: Learning rate scheduler.
:param SchedulerInterface schedulers: Learning rate schedulers.
If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
scheduler is used. Default is ``None``.
:param WeightingInterface weighting: The weighting schema to be used.
Expand Down
24 changes: 12 additions & 12 deletions pina/_src/solver/garom.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,18 +48,18 @@ def __init__(
:param torch.nn.Module loss: The loss function to be minimized.
If ``None``, :class:`~pina.loss.power_loss.PowerLoss` with ``p=1``
is used. Default is ``None``.
:param Optimizer optimizer_generator: The optimizer for the generator.
If ``None``, the :class:`torch.optim.Adam` optimizer is used.
Default is ``None``.
:param Optimizer optimizer_discriminator: The optimizer for the
:param OptimizerInterface optimizer_generator: The optimizer for the
generator. If ``None``, the :class:`torch.optim.Adam` optimizer is
used. Default is ``None``.
:param OptimizerInterface optimizer_discriminator: The optimizer for the
discriminator. If ``None``, the :class:`torch.optim.Adam`
optimizer is used. Default is ``None``.
:param Scheduler scheduler_generator: The learning rate scheduler for
the generator.
:param SchedulerInterface scheduler_generator: The learning rate
scheduler for the generator.
If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
scheduler is used. Default is ``None``.
:param Scheduler scheduler_discriminator: The learning rate scheduler
for the discriminator.
:param SchedulerInterface scheduler_discriminator: The learning rate
scheduler for the discriminator.
If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
scheduler is used. Default is ``None``.
:param float gamma: Ratio of expected loss for generator and
Expand Down Expand Up @@ -328,7 +328,7 @@ def optimizer_generator(self):
The optimizer for the generator.

:return: The optimizer for the generator.
:rtype: Optimizer
:rtype: OptimizerInterface
"""
return self.optimizers[0]

Expand All @@ -338,7 +338,7 @@ def optimizer_discriminator(self):
The optimizer for the discriminator.

:return: The optimizer for the discriminator.
:rtype: Optimizer
:rtype: OptimizerInterface
"""
return self.optimizers[1]

Expand All @@ -348,7 +348,7 @@ def scheduler_generator(self):
The scheduler for the generator.

:return: The scheduler for the generator.
:rtype: Scheduler
:rtype: SchedulerInterface
"""
return self.schedulers[0]

Expand All @@ -358,6 +358,6 @@ def scheduler_discriminator(self):
The scheduler for the discriminator.

:return: The scheduler for the discriminator.
:rtype: Scheduler
:rtype: SchedulerInterface
"""
return self.schedulers[1]
Loading
Loading