phlower.nn.TransolverAttention

class phlower.nn.TransolverAttention(nodes, heads=8, slice_num=32, dropout=0.0, unbatch_key=None)[source]

Bases: IGenericPhlowerCoreModule[IPhlowerLayerParameters, PhlowerTensor], Module

TransolverAttention is a neural network module that performs physics-attention on the input tensor, as used in Transolver.

It performs attention in three steps:

  1. Slice: Project features into physics-aware/slice tokens.
  2. Attention: Compute attention among the physics-aware/slice tokens.
  3. Deslice: Project the physics-aware/slice tokens back to the original space.
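
The three steps can be illustrated with a minimal single-head sketch in plain NumPy. This is not phlower's implementation: the function name `physics_attention_sketch`, the weight matrices `w_slice`, `w_q`, `w_k`, `w_v`, and the small normalization constant are assumptions made for illustration, loosely following the formulation in the referenced paper. Multiple heads, dropout, and the input/output projections are omitted.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    shifted = x - x.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


def physics_attention_sketch(x, w_slice, w_q, w_k, w_v):
    """Single-head slice / attention / deslice on x of shape (n_points, dim).

    Hypothetical helper for illustration only; not part of phlower.
    """
    # 1. Slice: softly assign every point to the slice tokens, then pool
    #    the point features into one token per slice.
    weights = softmax(x @ w_slice, axis=-1)         # (n_points, slice_num)
    norm = weights.sum(axis=0)[:, None] + 1e-5      # (slice_num, 1)
    tokens = (weights.T @ x) / norm                 # (slice_num, dim)

    # 2. Attention: scaled dot-product attention among the slice tokens only.
    q, k, v = tokens @ w_q, tokens @ w_k, tokens @ w_v
    scores = softmax(q @ k.T / np.sqrt(q.shape[-1]), axis=-1)
    tokens = scores @ v                             # (slice_num, dim)

    # 3. Deslice: send the updated tokens back to the points by reusing
    #    the same slice weights.
    return weights @ tokens                         # (n_points, dim)


rng = np.random.default_rng(0)
n_points, dim, slice_num = 100, 16, 4
x = rng.normal(size=(n_points, dim))
w_slice = rng.normal(size=(dim, slice_num))
w_q, w_k, w_v = (rng.normal(size=(dim, dim)) for _ in range(3))

out = physics_attention_sketch(x, w_slice, w_q, w_k, w_v)
print(out.shape)  # (100, 16)
```

In this sketch the attention matrix has shape (slice_num, slice_num), so the attention step does not grow with the number of input points. Only the slice and deslice projections touch every point.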

Ref: https://arxiv.org/abs/2402.02366

Parameters:
  • nodes (list[int]) – Feature dimension sizes (the last axis of the tensor shape), given as [input_dim, inner_dim, output_dim].

  • heads (int) – Number of attention heads. Default is 8.

  • slice_num (int) – Number of slice tokens. Default is 32.

  • dropout (float) – Dropout rate. Default is 0.0.

  • unbatch_key (str | None) – Key of the unbatch operation. Default is None.

Examples

>>> attention = TransolverAttention(
...     nodes=[256, 256, 256],
...     heads=8,
...     slice_num=32,
...     dropout=0.0,
... )
>>> attention(data)

Methods

forward(data, *[, field_data])

Forward function which overrides torch.nn.Module.forward

from_setting(setting)

Create TransolverAttention from setting object

get_nn_name()

Return name of TransolverAttention

Attributes

T_destination

call_super_init

dump_patches

training

forward(data, *, field_data=None, **kwards)[source]

Forward function which overrides torch.nn.Module.forward

Parameters:
  • data (IPhlowerTensorCollections) – Data received from the predecessor modules.

  • field_data (ISimulationField | None) – Information that stays constant throughout training or prediction.

Returns:

Tensor object

Return type:

PhlowerTensor

classmethod from_setting(setting)[source]

Create a TransolverAttention instance from a setting object

Parameters:

setting (TransolverAttentionSetting) – setting object

Returns:

TransolverAttention

Return type:

Self

classmethod get_nn_name()[source]

Return name of TransolverAttention

Returns:

name

Return type:

str