phlower.nn.ConjugateGradientSolver

class phlower.nn.ConjugateGradientSolver(matrix_name, x0_name=None, dirichlet_name=None, rtol=1e-05, atol=0.0, maxiter=None, batch_solve=True, force_cpu=False, log_level='warning', **kwards)[source]

Bases: Module, IGenericPhlowerCoreModule[IPhlowerLayerParameters, PhlowerTensor]

Solver based on the conjugate gradient (CG) method.

Parameters:
  • matrix_name (str) – Name of the sparse matrix.

  • x0_name (str | None = None) – Name of the initial guess of the solution.

  • dirichlet_name (str | None = None) – Name of the Dirichlet feature.

  • rtol (float = 1e-5) – Relative tolerance.

  • atol (float = 0.0) – Absolute tolerance.

  • maxiter (int | None = None) – Maximum number of iterations.

  • batch_solve (bool = True) – If True, solve multiple systems in a batched manner.

  • force_cpu (bool = False) – If True, force CPU computation even if the input is on a GPU.

  • log_level (str = "warning") – Log level of the message emitted when the solver does not converge.
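To illustrate how rtol, atol, and maxiter typically interact in a CG iteration, here is a minimal pure-Python sketch. It is not phlower's implementation: the helper names (matvec, dot, conjugate_gradient) are hypothetical, the stopping rule ||r|| <= max(rtol * ||b||, atol) is an assumption borrowed from the convention of scipy.sparse.linalg.cg, and the fallback used when maxiter is None is an arbitrary choice for this sketch.

```python
import math


def matvec(a, x):
    """Dense matrix-vector product for a matrix stored as a list of rows."""
    return [sum(a_ij * x_j for a_ij, x_j in zip(row, x)) for row in a]


def dot(u, v):
    """Euclidean inner product of two vectors."""
    return sum(u_i * v_i for u_i, v_i in zip(u, v))


def conjugate_gradient(a, b, x0=None, rtol=1e-5, atol=0.0, maxiter=None):
    """Solve a @ x = b for a symmetric positive-definite matrix a.

    Returns (x, converged). Hypothetical helper, not part of phlower.
    """
    n = len(b)
    x = [0.0] * n if x0 is None else list(x0)
    # Assumption: fall back to 10 * n iterations when maxiter is None.
    maxiter = 10 * n if maxiter is None else maxiter
    # Assumed stopping rule: ||r|| <= max(rtol * ||b||, atol).
    threshold = max(rtol * math.sqrt(dot(b, b)), atol)

    r = [b_i - ax_i for b_i, ax_i in zip(b, matvec(a, x))]  # initial residual
    p = list(r)  # initial search direction
    rs_old = dot(r, r)
    if math.sqrt(rs_old) <= threshold:
        return x, True

    for _ in range(maxiter):
        ap = matvec(a, p)
        alpha = rs_old / dot(p, ap)  # step length along p
        x = [x_i + alpha * p_i for x_i, p_i in zip(x, p)]
        r = [r_i - alpha * ap_i for r_i, ap_i in zip(r, ap)]
        rs_new = dot(r, r)
        if math.sqrt(rs_new) <= threshold:
            return x, True
        beta = rs_new / rs_old  # makes the next direction A-conjugate
        p = [r_i + beta * p_i for r_i, p_i in zip(r, p)]
        rs_old = rs_new
    return x, False  # a solver could log at log_level here


a = [[4.0, 1.0], [1.0, 3.0]]
b = [1.0, 2.0]
x, converged = conjugate_gradient(a, b)
```

Under these assumptions, a looser rtol or a larger atol stops the iteration earlier, while maxiter bounds the work when the tolerance is never reached; the non-converged return path is where a message at the configured log level would plausibly be emitted.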

Examples

>>> solver = ConjugateGradientSolver(matrix_name="A")
>>> solver(data, field_data=field)

Methods

forward(data, *, field_data, **kwards)

Forward function that overrides torch.nn.Module.forward.

from_setting(setting)

Create ConjugateGradientSolver from ConjugateGradientSolverSetting object.

get_nn_name()

Return neural network name

Attributes

T_destination

call_super_init

dump_patches

training

forward(data, *, field_data, **kwards)[source]

Forward function that overrides torch.nn.Module.forward.

Parameters:
  • data (IPhlowerTensorCollections) – Data received from predecessor modules.

  • field_data (ISimulationField | None) – Information that stays constant throughout training or prediction.

Returns:

Tensor object

Return type:

PhlowerTensor

classmethod from_setting(setting)[source]

Create ConjugateGradientSolver from ConjugateGradientSolverSetting object.

Parameters:

setting (ConjugateGradientSolverSetting) – setting object for ConjugateGradientSolver

Returns:

ConjugateGradientSolver model

Return type:

ConjugateGradientSolver

classmethod get_nn_name()[source]

Return neural network name

Returns:

name

Return type:

str