PyTorch DiffEq solver
TorchDiffEqSolver
Bases: ForwardSolver
ODE-based forward solver using the TorchDiffEq library.
This solver integrates tumor growth models over the requested timepoints using the ODE integration methods provided by TorchDiffEq (e.g., "rk4" or "dopri5") and takes both radiotherapy and chemotherapy schedules into account.
Attributes:
Name | Type | Description |
---|---|---|
model | TumorGrowthModel3D | The tumor growth model to solve. |
solver_options | TorchDiffEqSolverOptions | Configuration options for the solver. |
__init__(model, solver_options)
Initializes the TorchDiffEqSolver.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | TumorGrowthModel3D | The tumor growth model to solve. | required |
solver_options | TorchDiffEqSolverOptions | Configuration options for the solver. | required |
grid_constructor(func, y0, t)
Constructs the grid of integration timesteps, taking the radiotherapy and chemotherapy schedules into account.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
func | | The ODE function (unused in this method, but required by the TorchDiffEq grid_constructor interface). | required |
y0 | | Initial state of the system (unused in this method, but required by the TorchDiffEq grid_constructor interface). | required |
t | | Original list of timepoints requested by the solver. | required |
Returns:
Type | Description |
---|---|
Tensor | Refined timepoints for integration. |
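The page does not say how treatment schedules enter the grid. Below is a minimal sketch of one plausible approach, assuming the grid is a fixed-step grid merged with the requested timepoints and any treatment times inside the integration window. The factory make_grid_constructor and its treatment_times argument are hypothetical, times are plain floats (assumed to be days since the first timepoint), and t is assumed to be increasing. In TorchDiffEq, fixed-grid methods such as "rk4" accept a callable with this (func, y0, t) signature through options={"grid_constructor": ...}. A real implementation would return a torch.Tensor with the dtype and device of t rather than a Python list.

```python
from typing import Callable, List, Sequence


def make_grid_constructor(
    step_size: float, treatment_times: Sequence[float]
) -> Callable[..., List[float]]:
    """Hypothetical factory for a grid constructor with a (func, y0, t) signature."""
    if step_size <= 0:
        raise ValueError("step_size must be positive")

    def grid_constructor(func, y0, t) -> List[float]:
        # func and y0 are unused, as in the documented method; they are accepted
        # only to match the (func, y0, t) signature the integrator calls.
        start, end = float(t[0]), float(t[-1])
        grid = set()
        # Regular fixed-step points strictly before the final time.
        i = 0
        while start + i * step_size < end:
            grid.add(start + i * step_size)
            i += 1
        # Every requested timepoint, including both endpoints, lands on the grid.
        grid.update(float(x) for x in t)
        # Assumed treatment handling: add each treatment time within [start, end]
        # so that a dose can be applied exactly at a grid point.
        grid.update(float(tt) for tt in treatment_times if start <= tt <= end)
        return sorted(grid)

    return grid_constructor
```

For example, make_grid_constructor(1.0, [2.5, 10.0])(None, None, [0.0, 4.0]) yields [0.0, 1.0, 2.0, 2.5, 3.0, 4.0]. The treatment at 10.0 is dropped because it falls outside the requested window. TorchDiffEq checks that the grid starts at t[0] and ends at t[-1], which sorting an increasing grid guarantees here.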
solve(timepoints, u_initial)
Solves the tumor growth model over the specified timepoints.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
timepoints | List[datetime] | List of timepoints at which the solution is desired. | required |
u_initial | Tensor | Initial tumor density field. | required |
Returns:
Type | Description |
---|---|
Tuple[List[datetime], List[Tensor]] | A tuple of two lists: the datetime objects corresponding to the solution timepoints, and the torch.Tensor objects holding the tumor density at each of those timepoints. |
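solve accepts and returns datetime objects, but TorchDiffEq integrates over numeric time values, so the solver must convert between the two somewhere. The sketch below shows one plausible conversion. It assumes the model's time unit is days measured from the first timepoint, which this page does not state. The helper names datetimes_to_days and days_to_datetimes are hypothetical. The monotonicity check reflects TorchDiffEq's requirement that the time values be strictly monotone.

```python
from datetime import datetime, timedelta
from typing import List


def datetimes_to_days(timepoints: List[datetime]) -> List[float]:
    """Hypothetical helper: float days elapsed since the first timepoint."""
    if not timepoints:
        raise ValueError("timepoints must not be empty")
    t0 = timepoints[0]
    # Dividing two timedeltas gives a float, here measured in days.
    days = [(tp - t0) / timedelta(days=1) for tp in timepoints]
    # The ODE integrator needs strictly increasing times in this setup.
    if any(b <= a for a, b in zip(days, days[1:])):
        raise ValueError("timepoints must be strictly increasing")
    return days


def days_to_datetimes(t0: datetime, days: List[float]) -> List[datetime]:
    """Hypothetical inverse: map float days since t0 back to datetimes."""
    return [t0 + timedelta(days=d) for d in days]
```

With timepoints on 2024-01-01, 2024-01-03 at 12:00, and 2024-01-11, datetimes_to_days returns [0.0, 2.5, 10.0], and days_to_datetimes maps those values back to the original datetimes. That round trip is what lets solve pair each returned datetime with the density tensor computed at the matching numeric time.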
TorchDiffEqSolverOptions (dataclass)
Configuration options for the TorchDiffEqSolver.
Attributes:
Name | Type | Description |
---|---|---|
step_size | timedelta | The integration step size for the solver. |
method | str | The ODE solver method to use (e.g., "rk4", "dopri5"). |
device | torch.device | The device on which to perform computations (e.g., CPU or GPU). |
use_adjoint | bool | Whether to use the adjoint method for memory-efficient backpropagation. |
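The sketch below shows how these options could map onto a TorchDiffEq call. TorchDiffEq provides both odeint and odeint_adjoint, and the latter requires the ODE function to be a torch.nn.Module. Fixed-grid methods such as "rk4" take options={"step_size": ...}, while adaptive methods such as "dopri5" choose their own steps. Several pieces are assumptions rather than documented behavior: SolverOptionsSketch is a hypothetical torch-free stand-in with made-up defaults, device is a string instead of a torch.device, the step size is assumed to be expressed in days, and FIXED_GRID_METHODS lists only a subset of TorchDiffEq's fixed-grid methods.

```python
from dataclasses import dataclass
from datetime import timedelta
from typing import Any, Dict, Tuple

# Assumed subset of TorchDiffEq's fixed-grid methods (those that accept step_size).
FIXED_GRID_METHODS = {"euler", "midpoint", "rk4"}


@dataclass
class SolverOptionsSketch:
    """Hypothetical stand-in for TorchDiffEqSolverOptions; defaults are made up."""

    step_size: timedelta = timedelta(hours=12)
    method: str = "rk4"
    device: str = "cpu"  # the real field is a torch.device; kept torch-free here
    use_adjoint: bool = False


def odeint_call_spec(opts: SolverOptionsSketch) -> Tuple[str, Dict[str, Any]]:
    """Return the TorchDiffEq entry point name and keyword arguments (sketch)."""
    # use_adjoint selects between the two TorchDiffEq integration functions.
    fn_name = "odeint_adjoint" if opts.use_adjoint else "odeint"
    kwargs: Dict[str, Any] = {"method": opts.method}
    if opts.method in FIXED_GRID_METHODS:
        # Assumes model time is in days, so convert the timedelta accordingly.
        kwargs["options"] = {"step_size": opts.step_size / timedelta(days=1)}
    # device is not an odeint argument; it would be used to move tensors instead.
    return fn_name, kwargs
```

With the default options this returns ("odeint", {"method": "rk4", "options": {"step_size": 0.5}}). With method="dopri5" and use_adjoint=True it returns ("odeint_adjoint", {"method": "dopri5"}), because an adaptive method does not take a fixed step_size.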