class Dirichlet(IPhlowerCoreModule, torch.nn.Module):
    """Overwrite values with those of a dirichlet field.

    Wherever the dirichlet field holds a non-NaN entry, that entry replaces
    the corresponding entry of the incoming value. Wherever the dirichlet
    field is NaN, the incoming value is kept. The activation function is
    applied to the merged result.

    Parameters
    ----------
    activation: str
        Name of the activation function to apply to the output.
    dirichlet_name: str
        Name of the dirichlet field.
    nodes: list[int] | None (optional)
        List of feature dimension sizes (The last value of tensor shape).
        Defaults to None.

    Examples
    --------
    >>> dirichlet = Dirichlet(activation="relu", dirichlet_name="dirichlet")
    >>> dirichlet(data)
    """

    # NOTE(review): `self._dirichlet_name` and `self._activation_func` are
    # read by `forward` but are not assigned anywhere in this view of the
    # file; presumably an initializer defined elsewhere sets them -- confirm.

    def forward(
        self,
        data: IPhlowerTensorCollections,
        *,
        field_data: ISimulationField | None = None,
        **kwards,
    ) -> PhlowerTensor:
        """forward function which overloads torch.nn.Module

        Args:
            data: IPhlowerTensorCollections
                data which receives from predecessors. It must contain the
                dirichlet field (looked up by the configured dirichlet name)
                and exactly one other entry, which is the value to be
                overwritten.
            field_data: ISimulationField | None
                Defaults to None. Dirichlet will not use it.
            **kwards:
                Extra keyword arguments. They are accepted and ignored.

        Returns:
            PhlowerTensor: Tensor object

        Raises:
            ValueError: If the number of entries in ``data`` other than the
                dirichlet field is not exactly one.
        """
        dirichlet = data[self._dirichlet_name]

        # Everything that is not the dirichlet field is a candidate for the
        # value to overwrite; the choice must be unambiguous.
        value_names = [
            name for name in data.keys() if name != self._dirichlet_name
        ]
        if len(value_names) != 1:
            raise ValueError(
                "Dirichlet value cannot be detected. "
                f"Candidates: {value_names}"
            )
        value_name = value_names[0]

        # NaN in the dirichlet field marks "no constraint here": keep the
        # incoming value at those positions, take the dirichlet value elsewhere.
        dirichlet_filter = torch.isnan(dirichlet)
        ans = torch.where(dirichlet_filter, data[value_name], dirichlet)
        return self._activation_func(ans)