Hooks en PyTorch

En programación, nos referimos a un hook como el conjunto de técnicas que modifican o aumentan el comportamiento de un programa ante un evento. Esto suele usarse para depurar un programa o ampliar su funcionalidad.

En PyTorch, un hook se puede registrar para el objeto tensor o para el objeto nn.module y los eventos que los activan son el forward o el backward pass del objeto. Es decir, el hook se ejecutará cuando nuestro modelo en PyTorch esté implementando el grafo computacional (forward pass) o cuando se estén calculando los gradientes (backward pass).

Los hooks nos permitirán actuar sobre la entrada, salida o los atributos de un módulo y los atributos de un tensor.

Para implementar un hook, primero tendremos que definirlos, indicando las acciones que realiza sobre el módulo o tensor.

from torch import nn, Tensor

def module_hook(module: nn.Module, input: Tensor, output: Tensor):
    # Hook para nn.Module
    # Acciones
    
def tensor_hook(grad: Tensor):
    # Hook para tensor. Sólo en el backward pass
    # Acciones

Después, tendremos que registrar el hook en el módulo con «torch.nn.modules.module.register_forward_hook(module_hook)» o en el tensor con «tensor.register_hook(tensor_hook)».

A continuación vamos a ver un ejemplo de creación de un hook en un módulo. Vamos a crear un modelo de red neuronal con dos capas, definimos un hook para inspeccionar la entrada y la salida de la capa y lo registramos en la segunda capa del modelo.

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(Net, self).__init__()                    
        self.fc1 = nn.Linear(input_size, hidden_size)  
        self.relu = nn.ReLU()                          
        self.fc2 = nn.Linear(hidden_size, num_classes)

    
    def forward(self, x):                              
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

#Creamos una instancia del modelo
modelo1=Net(100,200,10)

#Definimos la función hook
def module_hook(self, input, output):

    print('Dentro de ' + self.__class__.__name__ + ' forward')
    print('')
    print('input: ', type(input))
    print('input[0]: ', type(input[0]))
    print('output: ', type(output))
    print('')
    print('input size:', input[0].size())
    print('output size:', output.data.size())
    print('output norm:', output.data.norm())

#Registamos el hook en la segunda capa
modelo1.fc2.register_forward_hook(module_hook)
<torch.utils.hooks.RemovableHandle at 0x7f9fef6a76d0>
#Calculamos la salida con una entrada ejemplo. Forward pass
x=torch.rand(100)
output=modelo1(x)
Dentro de Linear forward

input:  <class 'tuple'>
input[0]:  <class 'torch.Tensor'>
output:  <class 'torch.Tensor'>

input size: torch.Size([200])
output size: torch.Size([10])
output norm: tensor(0.2997)

Vemos como cuando se ejecuta el forward pass de nuestro modelo con una entrada ejemplo, se ejecuta el hook registrado en la segunda capa y muestra para esta capa los tipos y tamaños de los tensores entrada y salida y la norma del tensor salida.

Deja una respuesta

Tu dirección de correo electrónico no será publicada.

Orgullosamente ofrecido por WordPress | Tema: Baskerville 2 por Anders Noren.

Subir ↑

A %d blogueros les gusta esto: