Install Theory Guide User Guide API Examples Developer's Guide
Github
  • Installing NeuralOperator
  • Theory Guide
  • User Guide
  • API reference
  • Examples
    • Data
    • Data Generation
    • Layers
    • Models
    • Training and Meta-Algorithms
  • Development guide

Note

Go to the end to download the full example code.

Visualizing neighbor search

Understanding the spatial search used to compute neighborhoods for the Graph Neural Operator.

import random
import matplotlib.pyplot as plt
import torch

from neuralop.layers.gno_block import GNOBlock
from neuralop.layers.neighbor_search import native_neighbor_search
from neuralop.layers.embeddings import regular_grid_2d

device = 'cpu'

Basic logic

Many problems involve data collected over irregular point clouds. The Graph Neural Operator (GNO) is a simple neural operator architecture that learns a map between functions evaluated on (potentially different) arbitrary point clouds. For a set of input coordinates Y, an input function f evaluated at all y ∈ Y, and a set of output coordinates X, our goal is to map to a new function g evaluated at all x ∈ X. To do so, the GNO computes the Nyström approximation of a continuous kernel integral by summing the values of f across a neighborhood of each point x drawn from the input coordinates Y (with an optional kernel k): int_{N_r(x)} f(y) k(x,y) dy

The first step of this process is a neighbor search: %%

input_coords = torch.stack(regular_grid_2d(spatial_dims=(8,8))).permute(1,2,0).view(-1,2) #reshape into (64, 2)
output_queries = torch.rand([50, 2])

plt.scatter(input_coords[:, 0], input_coords[:, 1], color='orange', label="Input coordinates")
plt.scatter(output_queries[:, 0], output_queries[:, 1], color='blue', label="Output queries")
plt.legend()
plot neighbor search
<matplotlib.legend.Legend object at 0x7f7690f24550>

Visualizing neighborhoods

Now, let’s select a point in the output and visualize its neighbors.

query_index = 25
query_point = output_queries[query_index]
# Let's search, assuming a radius of 0.25. Note that this is quite high for the density of our data.
# in practice we tend to use values that find on the order of 10 neighbors.
nbr_data = native_neighbor_search(data=input_coords, queries=query_point.unsqueeze(0), radius=0.25)
fig, ax = plt.subplots()
neighbors = input_coords[nbr_data['neighbors_index']]
ax.scatter(input_coords[:, 0], input_coords[:, 1])
ax.scatter(query_point[0], query_point[1])
ax.scatter(neighbors[:, 0], neighbors[:, 1], label="neighbors of x")
c = plt.Circle(query_point, radius=0.25, fill=False)
ax.add_patch(c)
ax.legend()
ax.set_xlim(0,1)
ax.set_ylim(0,1)
plot neighbor search
(0.0, 1.0)

Total running time of the script: (0 minutes 0.133 seconds)

Download Jupyter notebook: plot_neighbor_search.ipynb

Download Python source code: plot_neighbor_search.py

Download zipped: plot_neighbor_search.zip

Gallery generated by Sphinx-Gallery

Layers Grid embeddings
© Copyright 2025, Jean Kossaifi, David Pitt, Nikola Kovachki, Zongyi Li and Anima Anandkumar.

On this page

  • Visualizing neighbor search
    • Basic logic
    • Visualizing neighborhoods