Deep Learning

The gradient tree varies with the loss expression

ddokkddokk 2025. 7. 16. 15:24

import torch
import torch.nn as nn
from torchviz import make_dot
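
# torchviz is installed with `pip install torchviz`; graph.render() below also
# needs a system Graphviz installation (the `dot` executable) on the PATH.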

# 1. Define a simple model or computation
class NE(nn.Module):
    def __init__(self):
        super(NE, self).__init__()
        self.fc1_ne = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.fc2_ne = nn.Linear(10, 10)

    def forward(self, x):
        x = self.fc1_ne(x)
        x = self.relu(x)
        x = self.fc2_ne(x)
        return x

class DEC(nn.Module):
    def __init__(self):
        super(DEC, self).__init__()
        self.fc1_dec = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.fc2_dec = nn.Linear(10, 10)

    def forward(self, x):
        x = self.fc1_dec(x)
        x = self.relu(x)
        x = self.fc2_dec(x)
        return x

# Instantiate the models
ne_model = NE() # Renamed to avoid conflict with the class name
dec_model = DEC() # Renamed to avoid conflict with the class name

# Create dummy input and target tensors
Bo = torch.randn(1, 10, requires_grad=True)
B = torch.randn(1, 10, requires_grad=True)
No = B - Bo
Oo = torch.zeros(1, 10)

# Perform a forward pass
Nest = ne_model(B)
Best = B - Nest
Oest = dec_model(Best)

# Loss1 = torch.mean((Nest - No).pow(2))      # eqn 4
Loss1 = torch.mean((Bo - Best).pow(2))    # eqn 5
Loss2 = torch.mean((Oest - Oo).pow(2))
Loss = Loss1 + Loss2
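
# NOTE: the choice of loss expression above changes the autograd graph that
# make_dot renders below. With eqn 4, Loss1 branches from Nest (the ne_model
# output) and from the subtraction No = B - Bo; with eqn 5, Loss1 branches
# from Bo and from Best = B - Nest instead. Loss2 is identical in both cases.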

# --- Visualization Part ---
# Collect all tensors that we want to appear as "leaf" nodes in the graph
# These are typically the inputs that require gradients and the model parameters.
params = dict(list(ne_model.named_parameters()) +
              list(dec_model.named_parameters()) +
              [('Bo', Bo), ('B', B)]) # Include inputs if you want to see gradient flow to them

# Generate the computational graph
# make_dot takes the output tensor (Loss) and an optional dictionary of parameters
graph = make_dot(Loss, params=params)

# Display the graph (e.g., in a Jupyter Notebook or IPython environment)
graph

# To save the graph to a file (e.g., PDF, PNG)
graph.render("gradient_tree_combined_loss_eqn5", format="png", view=False)
print("Gradient tree saved as gradient_tree_combined_loss.png")