Note
Go to the end to download the full example code.
Optimize the shape of a volumetric mesh
graphlow is a powerful tool for shape optimization, leveraging differentiable tensor-based geometric computations.
This example shows how to minimize surface area while constraining volume changes and vertex deformation.

Import the necessary modules, including graphlow.
import itertools
import numpy as np
import pyvista as pv
import torch
import graphlow
Prepare a volumetric mesh
First, we define a function that generates a structured hexahedral grid to use as the example mesh. If you want to use your own mesh, you can skip this step.
def generate_grid(ni: int, nj: int, nk: int) -> pv.UnstructuredGrid:
    """Build a block of unit hexahedra as a PyVista unstructured grid.

    The points are the float32 integer lattice ``[0, ni) x [0, nj) x [0, nk)``,
    numbered with the x index slowest and the z index fastest. One hexahedron
    is created per lattice cell, so the grid has
    ``(ni - 1) * (nj - 1) * (nk - 1)`` cells.

    Args:
        ni: Number of lattice points along x.
        nj: Number of lattice points along y.
        nk: Number of lattice points along z.

    Returns:
        The hexahedral grid.
    """
    axes = [np.arange(count, dtype=np.float32) for count in (ni, nj, nk)]
    grid_x, grid_y, grid_z = np.meshgrid(*axes, indexing="ij")
    points = np.array([grid_x.ravel(), grid_y.ravel(), grid_z.ravel()]).T

    # point_id[i, j, k] is the row of `points` holding lattice node (i, j, k).
    point_id = np.arange(ni * nj * nk).reshape(ni, nj, nk)

    # Corner offsets in VTK_HEXAHEDRON order: the bottom face (k) first,
    # then the matching corners of the top face (k + 1).
    corner_offsets = (
        (0, 0, 0), (1, 0, 0), (1, 1, 0), (0, 1, 0),
        (0, 0, 1), (1, 0, 1), (1, 1, 1), (0, 1, 1),
    )

    # Flat VTK connectivity: "8, p0, ..., p7" per cell. product() varies its
    # last argument fastest, so i is the innermost index (k -> j -> i order).
    connectivity = []
    for k, j, i in itertools.product(range(nk - 1), range(nj - 1), range(ni - 1)):
        connectivity.append(8)
        connectivity.extend(
            point_id[i + di, j + dj, k + dk] for di, dj, dk in corner_offsets
        )

    n_cells = (ni - 1) * (nj - 1) * (nk - 1)
    celltypes = np.full(n_cells, pv.CellType.HEXAHEDRON)
    return pv.UnstructuredGrid(connectivity, celltypes, points)
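Define the cost function
To minimize the surface area under constraints, we define the cost function as
$$C = \frac{A}{A_0} + w_V \left(\frac{V - V_0}{V_0}\right)^2 + w_d \, \frac{1}{3N} \sum_{i=1}^{N} \lVert \mathbf{d}_i \rVert^2,$$
where:
- $A$ and $V$ are the current total surface area and volume, and $A_0$ and $V_0$ are their initial values.
- $\mathbf{d}_i$ is the deformation of point $i$, out of $N$ points.
- $w_V$ and $w_d$ are the weights defined below.

If any cell volume drops below $10^{-3} V_0 / n_\text{cells}$, the function returns None so that the optimization loop can reduce the deformation scale and retry: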
# Weight of the mean-squared-deformation penalty.
weight_deformation_constraint = 1.0
# Weight of the squared relative total-volume change penalty.
weight_volume_constraint = 10.0


def cost_function(
    mesh: graphlow.GraphlowMesh,
    init_total_volume: "float | torch.Tensor",
    init_surface_total_area: "float | torch.Tensor",
) -> "torch.Tensor | None":
    """Return the shape-optimization cost, or ``None`` for a degenerate mesh.

    The cost is the sum of three terms:

    * the total surface area divided by ``init_surface_total_area``;
    * ``weight_volume_constraint`` times the squared relative change of the
      total volume with respect to ``init_total_volume``;
    * ``weight_deformation_constraint`` times the mean of the squared entries
      of ``mesh.dict_point_tensor["deformation"]``.

    Args:
        mesh: Mesh whose current geometry is evaluated. Its
            ``dict_point_tensor`` must contain a ``"deformation"`` entry.
        init_total_volume: Total volume of the undeformed mesh.
        init_surface_total_area: Total surface area of the undeformed mesh.

    Returns:
        The scalar cost tensor. Returns ``None`` instead when any cell volume
        is below ``1e-3`` times the initial mean cell volume
        (``init_total_volume / mesh.n_cells``). The caller must handle this
        case.
    """
    deformation = mesh.dict_point_tensor["deformation"]
    surface = mesh.extract_surface(pass_point_data=True)

    volumes = mesh.compute_volumes()
    areas = surface.compute_areas()
    total_volume = torch.sum(volumes)
    total_area = torch.sum(areas)

    # Presumably guards against collapsed or inverted cells, whose volumes
    # would make the cost and its gradient meaningless.
    min_cell_volume = 1e-3 * init_total_volume / mesh.n_cells
    if torch.any(volumes < min_cell_volume):
        return None

    cost_area = total_area / init_surface_total_area
    volume_constraint = (
        (total_volume - init_total_volume) / init_total_volume
    ) ** 2
    deformation_constraint = torch.mean(deformation * deformation)
    return (
        cost_area
        + weight_volume_constraint * volume_constraint
        + weight_deformation_constraint * deformation_constraint
    )
Visualize
The following code creates a plotter that records the optimization progress as a GIF animation.
def create_gif_plotter(
    mesh: pv.UnstructuredGrid,
    filename: str = "shape_optimization_result.gif",
) -> pv.Plotter:
    """Create a plotter that records frames of ``mesh`` into an animated GIF.

    A copy of ``mesh`` taken at call time is drawn as a faint white reference
    (opacity 0.1), and its outer bounds are shown. ``mesh`` itself is drawn on
    top with the ``"turbo"`` colormap. The caller is expected to update
    ``mesh`` in place and call ``plotter.write_frame()`` for each frame, and
    ``plotter.close()`` when finished.

    Args:
        mesh: Mesh to display and animate.
        filename: Path of the GIF file to write.

    Returns:
        The plotter, with the GIF already opened.
    """
    plotter = pv.Plotter(window_size=[800, 600])

    # Snapshot of the starting shape so the deformation is visible by contrast.
    init_mesh = mesh.copy()
    plotter.add_mesh(init_mesh, show_edges=True, color="white", opacity=0.1)
    plotter.add_mesh(
        mesh, show_edges=True, lighting=False, cmap="turbo", opacity=0.8
    )

    plotter.open_gif(filename)
    plotter.show_bounds(init_mesh, location="outer")
    plotter.camera_position = "iso"
    return plotter
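Once the cost function and the visualization function are defined, the remaining step is to write the optimization code that updates the points and deformations. The deformation is parameterized by a small two-layer network with a tanh hidden layer. The network maps each initial point position to a displacement, and Adam optimizes its weights.
Here is the optimization code.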
def _sync_pyvista(mesh: graphlow.GraphlowMesh) -> None:
    """Copy graphlow features and current point positions into ``mesh.pvmesh``."""
    mesh.copy_features_to_pyvista(overwrite=True)
    mesh.pvmesh.points = mesh.points.detach().numpy()


def optimize_shape(
    input_mesh: pv.UnstructuredGrid,
    *,
    n_optimization: int = 2000,
    n_hidden: int = 64,
    deformation_factor: float = 1.0,
    lr: float = 1e-2,
) -> None:
    """Deform ``input_mesh`` to reduce its surface area and record a GIF.

    The point displacement comes from a two-layer network (a ``tanh`` hidden
    layer of width ``n_hidden``) evaluated at the initial point positions.
    Its weights are optimized with Adam to minimize ``cost_function``.
    Whenever ``cost_function`` returns ``None``, ``deformation_factor`` is
    multiplied by 0.9 and the iteration is skipped without an optimizer step.

    Note: ``input_mesh`` is modified in place (its points are centered at
    the origin).

    Args:
        input_mesh: Volumetric mesh to optimize.
        n_optimization: Number of optimization iterations.
        n_hidden: Width of the hidden layer of the deformation network.
        deformation_factor: Initial scale applied to the network output.
        lr: Adam learning rate.
    """
    # Report about 100 times per run. Clamp to 1 so `i % print_period`
    # cannot divide by zero when n_optimization < 100.
    print_period = max(1, n_optimization // 100)
    output_activation = torch.nn.Identity()

    # Initialize
    input_mesh.points = input_mesh.points - np.mean(
        input_mesh.points, axis=0, keepdims=True
    )  # Center mesh position
    mesh = graphlow.GraphlowMesh(input_mesh)
    init_volumes = mesh.compute_volumes().clone()
    init_total_volume = torch.sum(init_volumes)
    init_points = mesh.points.clone()
    init_surface = mesh.extract_surface()
    init_surface_areas = init_surface.compute_areas().clone()
    init_surface_total_area = torch.sum(init_surface_areas)

    w1 = torch.nn.Parameter(torch.randn(3, n_hidden) / n_hidden**0.5)
    w2 = torch.nn.Parameter(torch.randn(n_hidden, 3) / n_hidden**0.5)
    params = [w1, w2]
    optimizer = torch.optim.Adam(params, lr=lr)

    def compute_deformation(points: torch.Tensor) -> torch.Tensor:
        # Reads `deformation_factor` from the enclosing scope at call time,
        # so reductions made in the loop below take effect immediately.
        hidden = torch.tanh(torch.einsum("np,pq->nq", points, w1))
        deformation = output_activation(torch.einsum("np,pq->nq", hidden, w2))
        return deformation_factor * deformation

    deformation = compute_deformation(init_points)
    mesh.dict_point_tensor.update({"deformation": deformation}, overwrite=True)
    _sync_pyvista(mesh)

    plotter = create_gif_plotter(mesh.pvmesh)
    # Always finalize the GIF writer, even if the optimization raises.
    try:
        plotter.write_frame()

        # Optimization loop
        print(f"\ninitial volume: {torch.sum(mesh.compute_volumes()):.5f}")
        print("     i, cost")
        for i in range(1, n_optimization + 1):
            optimizer.zero_grad()

            deformation = compute_deformation(init_points)
            deformed_points = init_points + deformation
            mesh.dict_point_tensor.update(
                {"deformation": deformation}, overwrite=True
            )
            mesh.dict_point_tensor.update(
                {"points": deformed_points}, overwrite=True
            )

            cost = cost_function(mesh, init_total_volume, init_surface_total_area)
            if cost is None:
                # A cell became too small: shrink the displacement scale and
                # retry on the next iteration without stepping the optimizer.
                deformation_factor *= 0.9
                print(f"update deformation_factor: {deformation_factor}")
                continue

            if i % print_period == 0:
                print(f"{i:6d}, {cost:.5e}")
                _sync_pyvista(mesh)
                plotter.write_frame()

            cost.backward()
            optimizer.step()
    finally:
        plotter.close()
Run the optimization
Finally, we define the main function that runs the optimization.
def main():
    """Build the example grid (10 points along each axis) and optimize its shape."""
    points_per_axis = 10
    optimize_shape(generate_grid(points_per_axis, points_per_axis, points_per_axis))


if __name__ == "__main__":
    main()

initial volume: 729.00000
i, cost
20, 9.97958e-01
40, 9.60497e-01
60, 9.44003e-01
80, 9.32981e-01
100, 9.26768e-01
120, 9.23295e-01
140, 9.21050e-01
160, 9.19457e-01
180, 9.18257e-01
200, 9.17275e-01
220, 9.16411e-01
240, 9.15622e-01
260, 9.14886e-01
280, 9.14194e-01
300, 9.13542e-01
320, 9.17020e-01
340, 9.12734e-01
360, 9.12301e-01
380, 9.11764e-01
400, 9.11228e-01
420, 9.10658e-01
440, 9.10567e-01
460, 9.13917e-01
480, 9.09154e-01
500, 9.08657e-01
520, 9.08090e-01
540, 9.07546e-01
560, 9.06992e-01
580, 9.34180e-01
600, 9.06295e-01
620, 9.05643e-01
640, 9.05243e-01
660, 9.04776e-01
680, 9.04388e-01
700, 9.04032e-01
720, 9.12879e-01
740, 9.04290e-01
760, 9.03495e-01
780, 9.03168e-01
800, 9.02900e-01
820, 9.02656e-01
840, 9.02426e-01
860, 9.02235e-01
880, 9.02142e-01
900, 9.02448e-01
920, 9.02044e-01
940, 9.01726e-01
960, 9.01562e-01
980, 9.01404e-01
1000, 9.01251e-01
1020, 9.01104e-01
1040, 9.00967e-01
1060, 9.23807e-01
1080, 9.03771e-01
1100, 9.01187e-01
1120, 9.00777e-01
1140, 9.00613e-01
1160, 9.00503e-01
1180, 9.00396e-01
1200, 9.00292e-01
1220, 9.00189e-01
1240, 9.00089e-01
1260, 9.04391e-01
1280, 9.03790e-01
1300, 9.00097e-01
1320, 9.00009e-01
1340, 8.99875e-01
1360, 8.99792e-01
1380, 8.99707e-01
1400, 8.99624e-01
1420, 8.99541e-01
1440, 8.99460e-01
1460, 8.99378e-01
1480, 8.99303e-01
1500, 9.31087e-01
1520, 9.03912e-01
1540, 8.99408e-01
1560, 8.99234e-01
1580, 8.99164e-01
1600, 8.99085e-01
1620, 8.99011e-01
1640, 8.98936e-01
1660, 8.98862e-01
1680, 8.98786e-01
1700, 8.98710e-01
1720, 8.98633e-01
1740, 8.98578e-01
1760, 8.99579e-01
1780, 9.02128e-01
1800, 8.98598e-01
1820, 8.98541e-01
1840, 8.98441e-01
1860, 8.98362e-01
1880, 8.98286e-01
1900, 8.98210e-01
1920, 8.98132e-01
1940, 8.98053e-01
1960, 8.97973e-01
1980, 8.97892e-01
2000, 8.97811e-01
Total running time of the script: (0 minutes 41.511 seconds)