Auxiliary Networks Showcase
Fitting shapes using NIGnets powered by Auxiliary Networks
We now fit Injective Networks powered by Auxiliary Networks to some target shapes to get a sense of their representational power and shortcomings.
# Basic imports
import torch
from torch import nn
import geosimilarity as gs
from NIGnets import NIGnet
from NIGnets.monotonic_nets import SmoothMinMaxNet
from assets.utils import automate_training, plot_curves
We will use the following network architecture for PreAux nets in this showcase. Users need to define their own PreAux net architectures similarly, making sure that the conditions on the output are met. A quick closedness check is sketched after the class definition below.
class PreAuxNet(nn.Module):
    def __init__(self, layer_count, hidden_dim):
        super().__init__()

        # Pre-auxiliary net needs a closed transform to get the same r at theta = 0 and 2*pi
        self.closed_transform = lambda t: torch.hstack([
            torch.cos(2 * torch.pi * t),
            torch.sin(2 * torch.pi * t)
        ])

        layers = [nn.Linear(2, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.PReLU()]
        for i in range(layer_count):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.PReLU())
        layers.append(nn.Linear(hidden_dim, 1))
        layers.append(nn.ReLU())  # keeps the predicted radius r non-negative
        self.forward_stack = nn.Sequential(*layers)

    def forward(self, t):
        unit_circle = self.closed_transform(t)  # Rows are cos(theta), sin(theta)
        r = self.forward_stack(unit_circle)
        x = r * unit_circle  # Each row is now r*cos(theta), r*sin(theta)
        return x
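As a quick sanity check, we can verify that a freshly initialized PreAuxNet traces a closed curve, i.e. it maps t = 0 and t = 1 to (approximately) the same point. The sketch below is one way to do this; the tolerance and the use of eval mode (so that BatchNorm relies on its running statistics) are our own choices, not requirements of the library.
# Closedness check (illustrative only; not needed for training)
check_net = PreAuxNet(layer_count = 2, hidden_dim = 5)
check_net.eval()  # BatchNorm1d uses running statistics, so a 2-point batch behaves deterministically
with torch.no_grad():
    x_ends = check_net(torch.tensor([[0.0], [1.0]]))  # outputs at t = 0 and t = 1
# The two rows should agree up to floating-point error in 2*pi
print(torch.allclose(x_ends[0], x_ends[1], atol = 1e-5))  # expected: True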
Pre-Auxiliary Networks
Circle and Square
from assets.shapes import circle, square
# Generate target curve points
num_pts = 1000
t = torch.linspace(0, 1, num_pts).reshape(-1, 1)
Xt_circle = circle(num_pts)
Xt_square = square(num_pts)
# Initialize networks to learn the target shapes and train
# Each NIGnet gets its own PreAuxNet so the two shapes do not share parameters
preaux_net_circle = PreAuxNet(layer_count = 2, hidden_dim = 5)
preaux_net_square = PreAuxNet(layer_count = 2, hidden_dim = 5)
circle_net = NIGnet(layer_count = 3, preaux_net = preaux_net_circle, act_fn = nn.SELU)
square_net = NIGnet(layer_count = 3, preaux_net = preaux_net_square, act_fn = nn.SELU)
print('Training Circle Net:')
automate_training(
    model = circle_net, loss_fn = gs.MSELoss(), X_train = t, Y_train = Xt_circle,
    learning_rate = 0.1, epochs = 1000, print_cost_every = 200
)
print('Training Square Net:')
automate_training(
    model = square_net, loss_fn = gs.MSELoss(), X_train = t, Y_train = Xt_square,
    learning_rate = 0.1, epochs = 1000, print_cost_every = 200
)
# Get final curve represented by the networks
Xc_circle = circle_net(t)
Xc_square = square_net(t)
# Plot the curves
plot_curves(Xc_circle, Xt_circle)
plot_curves(Xc_square, Xt_square)
Training Circle Net:
Epoch: [ 1/1000]. Loss: 0.651552
Epoch: [ 200/1000]. Loss: 0.000073
Epoch: [ 400/1000]. Loss: 0.000042
Epoch: [ 600/1000]. Loss: 0.000045
Epoch: [ 800/1000]. Loss: 0.000041
Epoch: [1000/1000]. Loss: 0.000022
Training Square Net:
Epoch: [ 1/1000]. Loss: 0.843364
Epoch: [ 200/1000]. Loss: 0.000456
Epoch: [ 400/1000]. Loss: 0.000986
Epoch: [ 600/1000]. Loss: 0.000112
Epoch: [ 800/1000]. Loss: 0.000343
Epoch: [1000/1000]. Loss: 0.000103


Stanford Bunny
from assets.shapes import stanford_bunny
# Generate target curve points
num_pts = 1000
t = torch.linspace(0, 1, num_pts).reshape(-1, 1)
Xt = stanford_bunny(num_pts)
preaux_net = PreAuxNet(layer_count = 2, hidden_dim = 10)
nig_net = NIGnet(layer_count = 5, preaux_net = preaux_net, act_fn = nn.SELU)
automate_training(
    model = nig_net, loss_fn = gs.MSELoss(), X_train = t, Y_train = Xt,
    learning_rate = 0.1, epochs = 10000, print_cost_every = 2000
)
Xc = nig_net(t)
plot_curves(Xc, Xt)
Epoch: [ 1/10000]. Loss: 0.440139
Epoch: [ 2000/10000]. Loss: 0.000972
Epoch: [ 4000/10000]. Loss: 0.000303
Epoch: [ 6000/10000]. Loss: 0.000278
Epoch: [ 8000/10000]. Loss: 0.000266
Epoch: [10000/10000]. Loss: 0.000251

Heart
from assets.shapes import heart
# Generate target curve points
num_pts = 1000
t = torch.linspace(0, 1, num_pts).reshape(-1, 1)
Xt = heart(num_pts)
preaux_net = PreAuxNet(layer_count = 2, hidden_dim = 10)
nig_net = NIGnet(layer_count = 5, preaux_net = preaux_net, act_fn = nn.SELU)
automate_training(
    model = nig_net, loss_fn = gs.MSELoss(), X_train = t, Y_train = Xt,
    learning_rate = 0.1, epochs = 10000, print_cost_every = 2000
)
Xc = nig_net(t)
plot_curves(Xc, Xt)
Epoch: [ 1/10000]. Loss: 0.561411
Epoch: [ 2000/10000]. Loss: 0.000165
Epoch: [ 4000/10000]. Loss: 0.000013
Epoch: [ 6000/10000]. Loss: 0.000012
Epoch: [ 8000/10000]. Loss: 0.000011
Epoch: [10000/10000]. Loss: 0.000011

Hand
from assets.shapes import hand
# Generate target curve points
num_pts = 1000
t = torch.linspace(0, 1, num_pts).reshape(-1, 1)
Xt = hand(num_pts)
preaux_net = PreAuxNet(layer_count = 2, hidden_dim = 10)
nig_net = NIGnet(layer_count = 5, preaux_net = preaux_net, act_fn = nn.SELU)
automate_training(
    model = nig_net, loss_fn = gs.MSELoss(), X_train = t, Y_train = Xt,
    learning_rate = 0.1, epochs = 10000, print_cost_every = 2000
)
Xc = nig_net(t)
plot_curves(Xc, Xt)
Epoch: [ 1/10000]. Loss: 0.692996
Epoch: [ 2000/10000]. Loss: 0.002578
Epoch: [ 4000/10000]. Loss: 0.001467
Epoch: [ 6000/10000]. Loss: 0.001278
Epoch: [ 8000/10000]. Loss: 0.000983
Epoch: [10000/10000]. Loss: 0.000533

Airplane
from assets.shapes import airplane
# Generate target curve points
num_pts = 1000
t = torch.linspace(0, 1, num_pts).reshape(-1, 1)
Xt = airplane(num_pts)
preaux_net = PreAuxNet(layer_count = 2, hidden_dim = 25)
nig_net = NIGnet(layer_count = 5, preaux_net = preaux_net, act_fn = nn.SELU)
automate_training(
    model = nig_net, loss_fn = gs.MSELoss(), X_train = t, Y_train = Xt,
    learning_rate = 0.1, epochs = 10000, print_cost_every = 2000
)
Xc = nig_net(t)
plot_curves(Xc, Xt)
Epoch: [ 1/10000]. Loss: 0.864464
Epoch: [ 2000/10000]. Loss: 0.001621
Epoch: [ 4000/10000]. Loss: 0.001200
Epoch: [ 6000/10000]. Loss: 0.000993
Epoch: [ 8000/10000]. Loss: 0.000815
Epoch: [10000/10000]. Loss: 0.000728
