phlower.nn.InterpolatorPresetGroupModule
- class phlower.nn.InterpolatorPresetGroupModule(source_position_name, target_position_name, source_position_from_field=True, target_position_from_field=True, variable_names=None, input_to_output_map=None, source_field_name=None, target_field_name=None, n_head=1, weight_name=None, length_scale=0.1, normalize_row=True, conservative=False, bidirectional=False, k_nns=9, d_nns=None, learnable_length_scale=False, learnable_weight_transpose=False, learnable_weight_distance=False, recompute_distance_k_nns_if_grad_enabled=False, recompute_distance_d_nns_if_grad_enabled=False, unbatch=True, epsilon=1e-05)
Bases: Module, IGenericPhlowerCoreModule[IPhlowerPresetGroupParameters, IPhlowerTensorCollections]
Semi-Lagrangian Attention is a neural network module for convection-dominated problems based on the semi-Lagrangian scheme.
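For readers unfamiliar with the semi-Lagrangian scheme mentioned above, the following minimal sketch shows one step of the textbook method for 1D advection with a constant velocity on a periodic grid. Each node is traced back along the velocity to its departure point, and the previous field is linearly interpolated there. It illustrates only the classical scheme. The function name semi_lagrangian_step and the choice of linear interpolation are assumptions made for this illustration, not part of phlower's API or implementation.

```python
import math


def semi_lagrangian_step(u, velocity, dx, dt):
    """Advance a periodic 1D field u by one semi-Lagrangian step.

    u holds values at grid nodes x_i = i * dx; velocity is constant.
    (Illustrative sketch; not phlower's implementation.)
    """
    n = len(u)
    length = n * dx
    new_u = []
    for i in range(n):
        # Trace node x_i back along the characteristic to its departure point.
        x_departure = (i * dx - velocity * dt) % length  # periodic wrap
        s = x_departure / dx
        j = int(math.floor(s)) % n  # left neighbor; % n guards s == n from rounding
        frac = s - math.floor(s)
        # Linearly interpolate the old field at the departure point.
        new_u.append((1.0 - frac) * u[j] + frac * u[(j + 1) % n])
    return new_u


# A pulse moves one cell to the right when velocity * dt == dx.
print(semi_lagrangian_step([0.0, 1.0, 0.0, 0.0], velocity=1.0, dx=1.0, dt=1.0))
# [0.0, 0.0, 1.0, 0.0]
```

Because the scheme interpolates at departure points instead of differencing neighboring nodes, it remains stable even when velocity * dt exceeds dx. Linear interpolation, however, adds numerical diffusion.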
- Parameters:
source_position_name (str) – Name to be used for source position data.
target_position_name (str) – Name to be used for target position data.
source_position_from_field (bool = True) – If False, take source position from tensors rather than field.
target_position_from_field (bool = True) – If False, take target position from tensors rather than field.
variable_names (list[str] | None = None) – Names of the variables to interpolate. If not given, all variables are interpolated.
source_field_name (str | None = None) – Name of the field used for unbatching source data.
target_field_name (str | None = None) – Name of the field used for unbatching target data.
n_head (int = 1) – The number of heads for the multi-head attention mechanism.
weight_name (str | None = None) – Name to be used for weight data.
length_scale (float = 1.0e-1) – Length scale to normalize distance-based attention.
normalize_row (bool = True) – If True, normalize the attention matrix in the row direction.
conservative (bool = False) – If True, normalize the attention matrix further to achieve conservation.
bidirectional (bool = False) – If True, compute bidirectional interpolation assuming both source and target represent the same shape.
k_nns (int = 9) – Number of neighbors k for the k-nearest-neighbor (k-NN) search. It should be larger than 1.
d_nns (float | None = None) – If given, additionally perform a radius-based neighbor search using this distance, on top of the k-NN search.
learnable_length_scale (bool = False) – If True, learn the length scale. In that case, the length_scale parameter above is used to initialize the trainable weight.
learnable_weight_transpose (bool = False) – If True, learn a weight for the transposed attention matrix.
learnable_weight_distance (bool = False) – If True, learn a weight for the distance-based attention matrix.
recompute_distance_k_nns_if_grad_enabled (bool = False) – If True and torch.is_grad_enabled() returns True, recompute distances for the k-NN attention matrix to retain gradients.
recompute_distance_d_nns_if_grad_enabled (bool = False) – If True and torch.is_grad_enabled() returns True, recompute distances for the distance-based attention matrix to retain gradients.
unbatch (bool = True) – If True, unbatch the input data and field.
epsilon (float = 1.0e-5) – Small constant used to handle near-zero values.
input_to_output_map (dict[str, str] | None = None)
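Several of the parameters above (k_nns, d_nns, length_scale, normalize_row, conservative, epsilon) describe how a distance-based attention matrix between source and target positions is built. The following minimal pure-Python sketch shows one way such a matrix could work. It assumes:

- a Gaussian kernel exp(-(d / length_scale)**2);
- a neighbor set formed from the k nearest sources plus any sources within d_nns;
- row normalization;
- an extra column normalization for conservation.

These formulas are assumptions for illustration. The helper names knn_interpolation_weights and interpolate are hypothetical and do not describe phlower's actual implementation.

```python
import math


def knn_interpolation_weights(source, target, k=9, length_scale=0.1,
                              d_nns=None, normalize_row=True,
                              conservative=False, epsilon=1e-5):
    """Build a dense (n_target x n_source) weight matrix (hypothetical helper).

    Assumed, illustrative rules (not phlower's exact formulas):
    - neighbors of a target are its k nearest sources, plus every source
      within distance d_nns when d_nns is given;
    - each neighbor gets the Gaussian weight exp(-(d / length_scale) ** 2);
    - normalize_row divides each row by (row sum + epsilon);
    - conservative then divides each column by (column sum + epsilon).
    """
    n_t, n_s = len(target), len(source)
    weights = [[0.0] * n_s for _ in range(n_t)]
    for t, pt in enumerate(target):
        dists = [math.dist(pt, ps) for ps in source]
        # k-NN neighbors, optionally extended by a radius search.
        neighbors = set(sorted(range(n_s), key=dists.__getitem__)[:k])
        if d_nns is not None:
            neighbors |= {s for s in range(n_s) if dists[s] <= d_nns}
        for s in neighbors:
            weights[t][s] = math.exp(-(dists[s] / length_scale) ** 2)
    if normalize_row:
        for row in weights:
            total = sum(row)
            for s in range(n_s):
                row[s] /= total + epsilon  # epsilon avoids division by ~0
    if conservative:
        for s in range(n_s):
            total = sum(weights[t][s] for t in range(n_t))
            for t in range(n_t):
                weights[t][s] /= total + epsilon
    return weights


def interpolate(weights, values):
    """Apply the weight matrix to one scalar value per source point."""
    return [sum(w * v for w, v in zip(row, values)) for row in weights]


source = [(0.0,), (1.0,)]
target = [(0.5,)]
w = knn_interpolation_weights(source, target, k=2, length_scale=1.0)
print(interpolate(w, [0.0, 2.0]))  # close to [1.0]
```

Column normalization is one simple way to conserve a total. If every column of the weight matrix sums to 1, the interpolated outputs sum to the same total as the inputs, up to the epsilon term.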
Examples
>>> interpolator = InterpolatorPresetGroupModule(
...     source_position_name="source_position",
...     target_position_name="target_position",
... )
>>> output = interpolator(data, field_data=field_data)
Methods
forward(data, *, field_data, **kwards) – Forward function that overrides torch.nn.Module.forward.
from_setting(setting) – Create an InterpolatorPresetGroupModule from an InterpolatorPresetGroupSetting instance.
Return the neural network name.
Attributes
T_destination, call_super_init, dump_patches, source_field_name_for_unbatch, target_field_name_for_unbatch, training
- forward(data, *, field_data, **kwards)
Forward function that overrides torch.nn.Module.forward.
- Parameters:
data (IPhlowerTensorCollections) – Data received from predecessor modules.
field_data (ISimulationField | None) – Field data that stays constant throughout training or prediction.
- Returns:
Tensor object
- Return type:
PhlowerTensor