En programación, nos referimos a un hook como el conjunto de técnicas que modifican o aumentan el comportamiento de un programa ante un evento. Esto suele usarse para depurar un programa o ampliar su funcionalidad.
En PyTorch, un hook se puede registrar para el objeto tensor o para el objeto nn.module y los eventos que los activan son el forward o el backward pass del objeto. Es decir, el hook se ejecutará cuando nuestro modelo en PyTorch esté implementando el grafo computacional (forward pass) o cuando se estén calculando los gradientes (backward pass).
Los hooks nos permitirán actuar sobre la entrada, salida o los atributos de un módulo y los atributos de un tensor.
Para implementar un hook, primero tendremos que definirlos, indicando las acciones que realiza sobre el módulo o tensor.
from torch import nn, Tensor def module_hook(module: nn.Module, input: Tensor, output: Tensor): # Hook para nn.Module # Acciones def tensor_hook(grad: Tensor): # Hook para tensor. Sólo en el backward pass # Acciones
Después, tendremos que registrar el hook en el módulo con “torch.nn.modules.module.register_forward_hook(module_hook)” o en el tensor con “tensor.register_hook(tensor_hook)”.
A continuación vamos a ver un ejemplo de creación de un hook en un módulo. Vamos a crear un modelo de red neuronal con dos capas, definimos un hook para inspeccionar la entrada y la salida de la capa y lo registramos en la segunda capa del modelo.
import torch import torch.nn as nn class Net(nn.Module): def __init__(self, input_size, hidden_size, num_classes): super(Net, self).__init__() self.fc1 = nn.Linear(input_size, hidden_size) self.relu = nn.ReLU() self.fc2 = nn.Linear(hidden_size, num_classes) def forward(self, x): out = self.fc1(x) out = self.relu(out) out = self.fc2(out) return out #Creamos una instancia del modelo modelo1=Net(100,200,10) #Definimos la función hook def module_hook(self, input, output): print('Dentro de ' + self.__class__.__name__ + ' forward') print('') print('input: ', type(input)) print('input[0]: ', type(input[0])) print('output: ', type(output)) print('') print('input size:', input[0].size()) print('output size:', output.data.size()) print('output norm:', output.data.norm()) #Registamos el hook en la segunda capa modelo1.fc2.register_forward_hook(module_hook)
<torch.utils.hooks.RemovableHandle at 0x7f9fef6a76d0>
#Calculamos la salida con una entrada ejemplo. Forward pass x=torch.rand(100) output=modelo1(x)
Dentro de Linear forward input: <class 'tuple'> input[0]: <class 'torch.Tensor'> output: <class 'torch.Tensor'> input size: torch.Size([200]) output size: torch.Size([10]) output norm: tensor(0.2997)
Vemos como cuando se ejecuta el forward pass de nuestro modelo con una entrada ejemplo, se ejecuta el hook registrado en la segunda capa y muestra para esta capa los tipos y tamaños de los tensores entrada y salida y la norma del tensor salida.
Deja una respuesta