Deep Learning
Gradient tree varies with loss expression
ddokkddokk
2025. 7. 16. 15:24
import torch
import torch.nn as nn
from torchviz import make_dot

# 1. Define a simple model or computation
class NE(nn.Module):
    def __init__(self):
        super(NE, self).__init__()
        self.fc1_ne = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.fc2_ne = nn.Linear(10, 10)

    def forward(self, x):
        x = self.fc1_ne(x)
        x = self.relu(x)
        x = self.fc2_ne(x)
        return x

class DEC(nn.Module):
    def __init__(self):
        super(DEC, self).__init__()
        self.fc1_dec = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.fc2_dec = nn.Linear(10, 10)

    def forward(self, x):
        x = self.fc1_dec(x)
        x = self.relu(x)
        x = self.fc2_dec(x)
        return x

# Instantiate the models
ne_model = NE()    # Renamed to avoid conflict with the class name
dec_model = DEC()  # Renamed to avoid conflict with the class name

# Create dummy input tensors
Bo = torch.randn(1, 10, requires_grad=True)
B = torch.randn(1, 10, requires_grad=True)
No = B - Bo
Oo = torch.zeros(1, 10)

# Perform a forward pass
Nest = ne_model(B)
Best = B - Nest
Oest = dec_model(Best)

# Loss1 = torch.mean((Nest - No).pow(2))  # eqn 4
Loss1 = torch.mean((Bo - Best).pow(2))    # eqn 5
Loss2 = torch.mean((Oest - Oo).pow(2))
Loss = Loss1 + Loss2
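
# Sanity check (optional): with No = B - Bo as defined above, eqn 4 and eqn 5
# describe the same quantity, since Bo - Best = Bo - (B - Nest) = Nest - No.
# The scalar loss is therefore the same; only the shape of the autograd graph changes.
print(torch.allclose(Loss1, torch.mean((Nest - No).pow(2))))  # expected: True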
# --- Visualization Part ---
# Collect all tensors that we want to appear as "leaf" nodes in the graph.
# These are typically the inputs that require gradients and the model parameters.
params = dict(list(ne_model.named_parameters()) +
              list(dec_model.named_parameters()) +
              [('Bo', Bo), ('B', B)])  # Include inputs if you want to see gradient flow to them

# Generate the computational graph.
# make_dot takes the output tensor (Loss) and an optional dictionary of parameters.
graph = make_dot(Loss, params=params)

# Display the graph (e.g., in a Jupyter Notebook or IPython environment)
graph

# To save the graph to a file (e.g., PDF, PNG)
graph.render("gradient_tree_combined_loss_eqn5", format="png", view=False)
print("Gradient tree saved as gradient_tree_combined_loss_eqn5.png")