# Source code for phlower.nn._core_modules._dirichlet
from __future__ import annotations
import torch
from phlower_tensor import ISimulationField, PhlowerTensor
from phlower_tensor.collections import IPhlowerTensorCollections
from phlower.nn._core_modules import _utils
from phlower.nn._interface_module import (
IPhlowerCoreModule,
IReadonlyReferenceGroup,
)
from phlower.settings._module_settings import DirichletSetting
[docs]
class Dirichlet(IPhlowerCoreModule, torch.nn.Module):
    """Dirichlet is a neural network module that overwrites values
    with that of dirichlet field.

    Where the dirichlet field is NaN the incoming value is kept; everywhere
    else the dirichlet value replaces it. The activation function is then
    applied to the merged tensor.

    Parameters
    ----------
    activation: str
        Name of the activation function to apply to the output.
    dirichlet_name: str
        Name of the dirichlet field.
    nodes: list[int] | None (optional)
        List of feature dimension sizes (The last value of tensor shape).
        Stored on the instance; no method of this class reads it.
        Defaults to None.

    Examples
    --------
    >>> dirichlet = Dirichlet(activation="relu", dirichlet_name="dirichlet")
    >>> dirichlet(data)
    """

    @classmethod
    def from_setting(cls, setting: DirichletSetting) -> Dirichlet:
        """Create Dirichlet from setting object

        Args:
            setting (DirichletSetting): setting object. Its instance
                attributes are forwarded as keyword arguments to the
                constructor.

        Returns:
            Self: Dirichlet
        """
        # Construct through ``cls`` (not the hard-coded class name) so that
        # a subclass calling this factory receives an instance of itself.
        return cls(**setting.__dict__)

    @classmethod
    def get_nn_name(cls) -> str:
        """Return name of Dirichlet

        Returns:
            str: name
        """
        return "Dirichlet"

    @classmethod
    def need_reference(cls) -> bool:
        """Return whether this module needs a reference to another module

        Returns:
            bool: always False
        """
        return False

    def __init__(
        self,
        activation: str,
        dirichlet_name: str,
        nodes: list[int] | None = None,
    ):
        super().__init__()
        self._nodes = nodes
        self._activation_name = activation
        self._activation_func = _utils.ActivationSelector.select(activation)
        self._dirichlet_name = dirichlet_name

    def resolve(
        self, *, parent: IReadonlyReferenceGroup | None = None, **kwards
    ) -> None:
        """Do nothing. Dirichlet does not reference other modules."""
        ...

    def get_reference_name(self) -> str | None:
        """Return name of the referenced module

        Returns:
            None: Dirichlet has no reference
        """
        return None

    def forward(
        self,
        data: IPhlowerTensorCollections,
        *,
        field_data: ISimulationField | None = None,
        **kwards,
    ) -> PhlowerTensor:
        """forward function which overloads torch.nn.Module

        Args:
            data: IPhlowerTensorCollections
                data which receives from predecessors. It must contain the
                dirichlet field and exactly one other entry, which is
                the value to be overwritten.
            field_data: ISimulationField | None
                Simulation field. Defaults to None.
                Dirichlet will not use it.

        Returns:
            PhlowerTensor: Tensor object

        Raises:
            ValueError: If the number of entries other than the dirichlet
                field is not exactly one.
        """
        dirichlet = data[self._dirichlet_name]

        # Every key except the dirichlet field is a candidate for the value
        # tensor; the choice must be unambiguous.
        value_names = [
            name for name in data.keys() if name != self._dirichlet_name
        ]
        if len(value_names) != 1:
            raise ValueError(
                f"Dirichlet value cannot be detected. Candidates: {value_names}"
            )
        value_name = value_names[0]

        # NaN in the dirichlet field marks positions without a boundary
        # condition: keep the incoming value there, otherwise overwrite it.
        dirichlet_filter = torch.isnan(dirichlet)
        ans = torch.where(dirichlet_filter, data[value_name], dirichlet)
        return self._activation_func(ans)