Optimize the shape of a volumetric mesh

graphlow is a powerful tool for shape optimization, leveraging differentiable tensor-based geometric computations.

This example shows how to minimize surface area while constraining volume changes and vertex deformation.

../_images/sphx_glr_shape_optimization_001.gif

Import the necessary modules, including graphlow.

import itertools

import numpy as np
import pyvista as pv
import torch

import graphlow

Prepare a volumetric mesh

First, we define a function that generates a structured grid of hexahedral cells to serve as the example mesh. If you want to use your own mesh, you can skip this step.

def generate_grid(ni: int, nj: int, nk: int) -> pv.UnstructuredGrid:
    """Build a structured block of hexahedral cells as a PyVista grid.

    Nodes sit on the integer lattice ``(i, j, k)`` for ``0 <= i < ni``,
    ``0 <= j < nj`` and ``0 <= k < nk``. Each lattice cube becomes one
    ``HEXAHEDRON`` cell.

    Args:
        ni: Number of nodes along x.
        nj: Number of nodes along y.
        nk: Number of nodes along z.

    Returns:
        A ``pv.UnstructuredGrid`` with ``(ni - 1) * (nj - 1) * (nk - 1)``
        hexahedral cells. Cells are ordered with ``i`` varying fastest and
        ``k`` slowest.
    """
    node_x, node_y, node_z = np.meshgrid(
        np.arange(ni, dtype=np.float32),
        np.arange(nj, dtype=np.float32),
        np.arange(nk, dtype=np.float32),
        indexing="ij",
    )
    points = np.column_stack((node_x.ravel(), node_y.ravel(), node_z.ravel()))
    node_ids = np.arange(ni * nj * nk).reshape(ni, nj, nk)

    def corner_ids(di: int, dj: int, dk: int) -> np.ndarray:
        # Node id of the corner at offset (di, dj, dk) for every cell.
        # Transposing to (k, j, i) before ravel makes i vary fastest.
        block = node_ids[di : di + ni - 1, dj : dj + nj - 1, dk : dk + nk - 1]
        return block.transpose(2, 1, 0).ravel()

    # Corners 0-3 lie on the k face, corners 4-7 on the k + 1 face.
    corner_offsets = (
        (0, 0, 0),
        (1, 0, 0),
        (1, 1, 0),
        (0, 1, 0),
        (0, 0, 1),
        (1, 0, 1),
        (1, 1, 1),
        (0, 1, 1),
    )
    corners = [corner_ids(*offset) for offset in corner_offsets]

    n_cells = (ni - 1) * (nj - 1) * (nk - 1)
    # Every cell entry is prefixed with its point count (8).
    cell_array = np.column_stack([np.full(n_cells, 8), *corners]).ravel()
    celltypes = np.full(n_cells, pv.CellType.HEXAHEDRON)
    return pv.UnstructuredGrid(cell_array, celltypes, points)

Define the cost function

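To minimize the surface area subject to constraints, we define the cost function as follows:

$$ J = \frac{A}{A_0} + w_V \left( \frac{V - V_0}{V_0} \right)^2 + w_u \, \frac{1}{3N} \sum_{i=1}^{N} \lVert \mathbf{u}_i \rVert^2 $$

Here $A$ and $V$ are the current total surface area and volume, and $A_0$ and $V_0$ are their initial values. $\mathbf{u}_i$ is the deformation of point $i$ out of $N$ points. $w_V$ and $w_u$ are `weight_volume_constraint` and `weight_deformation_constraint`. If any cell volume drops below $10^{-3}$ times the mean initial cell volume, the function returns `None` to signal a degenerate mesh.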

# Relative weights of the two penalty terms in cost_function.
weight_deformation_constraint = 1.0
weight_volume_constraint = 10.0


def cost_function(
    mesh: graphlow.GraphlowMesh,
    init_total_volume: float,
    init_surface_total_area: float,
) -> "torch.Tensor | None":
    """Evaluate the constrained surface-area objective for the current shape.

    The cost is::

        A / A0
        + weight_volume_constraint * ((V - V0) / V0) ** 2
        + weight_deformation_constraint * mean(u * u)

    ``A`` and ``V`` are the current total surface area and volume. ``u`` is
    the per-point ``"deformation"`` tensor.

    Args:
        mesh: Volumetric mesh whose ``dict_point_tensor`` contains
            ``"deformation"``.
        init_total_volume: Total volume ``V0`` of the undeformed mesh
            (a Python float or a 0-d tensor).
        init_surface_total_area: Total surface area ``A0`` of the
            undeformed mesh.

    Returns:
        The scalar cost tensor. Returns ``None`` instead when any cell
        volume falls below ``1e-3`` times the mean initial cell volume,
        signalling a collapsed cell to the caller.
    """
    volumes = mesh.compute_volumes()
    # Check for degenerate cells first, so a rejected shape skips surface
    # extraction and area computation.
    min_cell_volume = 1e-3 * init_total_volume / mesh.n_cells
    if torch.any(volumes < min_cell_volume):
        return None

    deformation = mesh.dict_point_tensor["deformation"]
    surface = mesh.extract_surface(pass_point_data=True)
    areas = surface.compute_areas()

    total_volume = torch.sum(volumes)
    total_area = torch.sum(areas)

    cost_area = total_area / init_surface_total_area
    # Squared relative volume change keeps the volume near its initial value.
    volume_constraint = (
        (total_volume - init_total_volume) / init_total_volume
    ) ** 2
    # Mean squared deformation discourages large point displacements.
    deformation_constraint = torch.mean(deformation * deformation)
    return (
        cost_area
        + weight_volume_constraint * volume_constraint
        + weight_deformation_constraint * deformation_constraint
    )

Visualize

The following code creates a GIF plotter to visualize the result.

def create_gif_plotter(
    mesh: pv.UnstructuredGrid,
    *,
    gif_path: str = "shape_optimization_result.gif",
) -> pv.Plotter:
    """Create a plotter that records the optimization as an animated GIF.

    A copy of ``mesh`` taken now is drawn as a faint white reference of the
    initial shape. ``mesh`` itself is added without copying and drawn on top
    with the ``turbo`` colormap, so the caller can update it and append
    frames with ``plotter.write_frame()``.

    Args:
        mesh: Grid to animate.
        gif_path: Output path of the GIF file. Defaults to
            ``"shape_optimization_result.gif"``.

    Returns:
        A ``pv.Plotter`` with the GIF writer already opened. The caller is
        responsible for calling ``close()`` on it.
    """
    plotter = pv.Plotter(window_size=[800, 600])
    # Snapshot the initial geometry before the caller starts mutating mesh.
    init_mesh = mesh.copy()
    plotter.add_mesh(init_mesh, show_edges=True, color="white", opacity=0.1)
    plotter.add_mesh(
        mesh, show_edges=True, lighting=False, cmap="turbo", opacity=0.8
    )

    plotter.open_gif(gif_path)
    # Axis bounds follow the initial shape so they stay fixed across frames.
    plotter.show_bounds(init_mesh, location="outer")
    plotter.camera_position = "iso"
    return plotter

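Once the cost function and the visualization function are defined, the remaining step is to write the optimization code. The points are not moved directly. Instead, a small two-layer network maps each initial point position to a deformation. Optimizing that network's weights updates the deformations and hence the point positions.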

Here is an example of the optimization code.

def optimize_shape(
    input_mesh: pv.UnstructuredGrid,
    *,
    n_optimization: int = 2000,
    n_hidden: int = 64,
    lr: float = 1e-2,
    deformation_factor: float = 1.0,
):
    """Minimize the surface area of ``input_mesh`` and record a GIF.

    The deformation is not optimized directly. A two-layer network with
    weights ``w1`` and ``w2`` maps each initial point position to a
    displacement, and Adam updates those weights to minimize
    ``cost_function``. When ``cost_function`` returns ``None`` (degenerate
    mesh), the step is skipped and ``deformation_factor`` is multiplied by
    0.9 for all later iterations.

    ``input_mesh`` is modified in place: its points are centered at the
    origin before optimization.

    Args:
        input_mesh: Volumetric mesh to optimize.
        n_optimization: Number of optimization iterations.
        n_hidden: Width of the hidden layer of the deformation network.
        lr: Adam learning rate.
        deformation_factor: Initial scale applied to the network output.
    """
    # Report and record a frame roughly 100 times per run. max() keeps the
    # period positive so short runs do not hit a modulo-by-zero.
    print_period = max(1, n_optimization // 100)
    output_activation = torch.nn.Identity()

    # Center the mesh at the origin (modifies input_mesh in place).
    input_mesh.points = input_mesh.points - np.mean(
        input_mesh.points, axis=0, keepdims=True
    )

    mesh = graphlow.GraphlowMesh(input_mesh)

    # Reference quantities of the undeformed mesh, used to normalize costs.
    init_volumes = mesh.compute_volumes().clone()
    init_total_volume = torch.sum(init_volumes)
    init_points = mesh.points.clone()

    init_surface = mesh.extract_surface()
    init_surface_areas = init_surface.compute_areas().clone()
    init_surface_total_area = torch.sum(init_surface_areas)

    # Network weights; these are the only trainable parameters.
    w1 = torch.nn.Parameter(torch.randn(3, n_hidden) / n_hidden**0.5)
    w2 = torch.nn.Parameter(torch.randn(n_hidden, 3) / n_hidden**0.5)
    optimizer = torch.optim.Adam([w1, w2], lr=lr)

    def compute_deformation(points: torch.Tensor) -> torch.Tensor:
        # Reads deformation_factor from the enclosing scope, so scaling it
        # down in the loop affects every later call.
        hidden = torch.tanh(torch.einsum("np,pq->nq", points, w1))
        deformation = output_activation(torch.einsum("np,pq->nq", hidden, w2))
        return deformation_factor * deformation

    def sync_to_pyvista() -> None:
        # Copy the current features and point positions into the PyVista
        # mesh that the plotter renders.
        mesh.copy_features_to_pyvista(overwrite=True)
        mesh.pvmesh.points = mesh.points.detach().numpy()

    deformation = compute_deformation(init_points)
    mesh.dict_point_tensor.update({"deformation": deformation}, overwrite=True)
    sync_to_pyvista()

    plotter = create_gif_plotter(mesh.pvmesh)
    try:
        plotter.write_frame()

        print(f"\ninitial volume: {torch.sum(mesh.compute_volumes()):.5f}")
        print("     i,        cost")
        for i in range(1, n_optimization + 1):
            optimizer.zero_grad()

            deformation = compute_deformation(init_points)
            deformed_points = init_points + deformation

            mesh.dict_point_tensor.update(
                {"deformation": deformation}, overwrite=True
            )
            mesh.dict_point_tensor.update(
                {"points": deformed_points}, overwrite=True
            )
            cost = cost_function(
                mesh, init_total_volume, init_surface_total_area
            )
            if cost is None:
                # Degenerate cell: skip this step and shrink the deformation.
                deformation_factor = deformation_factor * 0.9
                print(f"update deformation_factor: {deformation_factor}")
                continue

            if i % print_period == 0:
                print(f"{i:6d}, {cost:.5e}")
                sync_to_pyvista()
                plotter.write_frame()

            cost.backward()
            optimizer.step()
    finally:
        # Always finalize the GIF writer, even if optimization fails.
        plotter.close()

Run the optimization

Finally, we define the main function to run the optimization.

def main():
    """Run the example: optimize a 10 x 10 x 10-node hexahedral grid."""
    optimize_shape(generate_grid(10, 10, 10))


if __name__ == "__main__":
    main()
shape optimization
initial volume: 729.00000
     i,        cost
    20, 9.97958e-01
    40, 9.60497e-01
    60, 9.44003e-01
    80, 9.32981e-01
   100, 9.26768e-01
   120, 9.23295e-01
   140, 9.21050e-01
   160, 9.19457e-01
   180, 9.18257e-01
   200, 9.17275e-01
   220, 9.16411e-01
   240, 9.15622e-01
   260, 9.14886e-01
   280, 9.14194e-01
   300, 9.13542e-01
   320, 9.17020e-01
   340, 9.12734e-01
   360, 9.12301e-01
   380, 9.11764e-01
   400, 9.11228e-01
   420, 9.10658e-01
   440, 9.10567e-01
   460, 9.13917e-01
   480, 9.09154e-01
   500, 9.08657e-01
   520, 9.08090e-01
   540, 9.07546e-01
   560, 9.06992e-01
   580, 9.34180e-01
   600, 9.06295e-01
   620, 9.05643e-01
   640, 9.05243e-01
   660, 9.04776e-01
   680, 9.04388e-01
   700, 9.04032e-01
   720, 9.12879e-01
   740, 9.04290e-01
   760, 9.03495e-01
   780, 9.03168e-01
   800, 9.02900e-01
   820, 9.02656e-01
   840, 9.02426e-01
   860, 9.02235e-01
   880, 9.02142e-01
   900, 9.02448e-01
   920, 9.02044e-01
   940, 9.01726e-01
   960, 9.01562e-01
   980, 9.01404e-01
  1000, 9.01251e-01
  1020, 9.01104e-01
  1040, 9.00967e-01
  1060, 9.23807e-01
  1080, 9.03771e-01
  1100, 9.01187e-01
  1120, 9.00777e-01
  1140, 9.00613e-01
  1160, 9.00503e-01
  1180, 9.00396e-01
  1200, 9.00292e-01
  1220, 9.00189e-01
  1240, 9.00089e-01
  1260, 9.04391e-01
  1280, 9.03790e-01
  1300, 9.00097e-01
  1320, 9.00009e-01
  1340, 8.99875e-01
  1360, 8.99792e-01
  1380, 8.99707e-01
  1400, 8.99624e-01
  1420, 8.99541e-01
  1440, 8.99460e-01
  1460, 8.99378e-01
  1480, 8.99303e-01
  1500, 9.31087e-01
  1520, 9.03912e-01
  1540, 8.99408e-01
  1560, 8.99234e-01
  1580, 8.99164e-01
  1600, 8.99085e-01
  1620, 8.99011e-01
  1640, 8.98936e-01
  1660, 8.98862e-01
  1680, 8.98786e-01
  1700, 8.98710e-01
  1720, 8.98633e-01
  1740, 8.98578e-01
  1760, 8.99579e-01
  1780, 9.02128e-01
  1800, 8.98598e-01
  1820, 8.98541e-01
  1840, 8.98441e-01
  1860, 8.98362e-01
  1880, 8.98286e-01
  1900, 8.98210e-01
  1920, 8.98132e-01
  1940, 8.98053e-01
  1960, 8.97973e-01
  1980, 8.97892e-01
  2000, 8.97811e-01

Total running time of the script: (0 minutes 41.511 seconds)

Gallery generated by Sphinx-Gallery