Source code for phlower.nn._core_modules._reducer
from __future__ import annotations
from functools import reduce
from typing import Self
import torch
from phlower_tensor import ISimulationField, PhlowerTensor
from phlower_tensor.collections import IPhlowerTensorCollections
from phlower_tensor.functionals import broadcast_to
from phlower.nn._core_modules import _utils
from phlower.nn._interface_module import (
IPhlowerCoreModule,
IReadonlyReferenceGroup,
)
from phlower.settings._module_settings import ReducerSetting
[docs]
class Reducer(IPhlowerCoreModule, torch.nn.Module):
"""Reducer is a neural network module that performs a reduction operation
on the input tensor.
Parameters
----------
activation: str
Name of the activation function to apply to the output.
operator: str
Name of the operator to apply to the input tensors. "add" or "mul".
Default is "add".
nodes: list[int]
List of feature dimension sizes (The last value of tensor shape).
Examples
--------
>>> reducer = Reducer(activation="relu", operator="add", nodes=[10, 20, 30])
>>> reducer(data)
"""
_REGISTERED_OPERATORS = {"add": torch.add, "mul": torch.mul}
[docs]
@classmethod
def from_setting(cls, setting: ReducerSetting) -> Self:
"""Create Reducer from setting object
Args:
setting: ReducerSetting
setting object
Returns:
Self: Reducer
"""
return Reducer(**setting.__dict__)
[docs]
@classmethod
def get_nn_name(cls) -> str:
"""Return name of Reducer
Returns:
str: name
"""
return "Reducer"
@classmethod
def need_reference(cls) -> bool:
return False
def __init__(self, activation: str, operator: str, nodes: list[int] = None):
super().__init__()
self._nodes = nodes
self._activation_name = activation
self._activation_func = _utils.ActivationSelector.select(activation)
self._operator_name = operator
self._operator = self._REGISTERED_OPERATORS[self._operator_name]
def resolve(
self, *, parent: IReadonlyReferenceGroup | None = None, **kwards
) -> None: ...
def get_reference_name(self) -> str | None:
return None
[docs]
def forward(
self,
data: IPhlowerTensorCollections,
*,
field_data: ISimulationField | None = None,
**kwards,
) -> PhlowerTensor:
"""forward function which overloads torch.nn.Module
Args:
data: IPhlowerTensorCollections
data which receives from predecessors
field_data: ISimulationField | None
Constant information through training or prediction
Returns:
PhlowerTensor: Tensor object
"""
tensors = tuple(data.values())
ans = tensors[0]
ans = reduce(
lambda x, y: self._operator(*_to_broadcastable_shape(x, y)),
tensors,
)
return self._activation_func(ans)
def _to_broadcastable_shape(
tensor1: PhlowerTensor, tensor2: PhlowerTensor
) -> tuple[PhlowerTensor, PhlowerTensor]:
if len(tensor1.shape) == len(tensor2.shape):
return tensor1, tensor2
tensor_s, tensor_l = sorted([tensor1, tensor2], key=lambda x: len(x.shape))
tensor_s = broadcast_to(tensor_s, tensor_l.shape_pattern)
return tensor_s, tensor_l