Creando nuevas funciones en PyTorch

En este post de nuestro tutorial de deep learning con PyToch vamos a ver como extender PyTorch. Si quieremos implementar un nuevo módulo o función no disponible en las librerías de PyTorch tenemos varias opciones dependiendo del caso:

  • Si queremos añadir primitivas clásicas (if, while,…) en un módulo, simplemente insertaremos las primitivas en el método forward de nuestro modelo.
  • Si la función que queremos desarrollar se puede escribir usando operaciones de PyTorch y entonces autograd es capaz de registrar las operaciones y calcular los gradientes. En este caso crearemos un módulo.
  • Si vamos a usar operaciones no nativas de PyTorch y queremos que sean diferenciables junto al resto del modelo. En este caso crearemos una subclasse Function para implementar la operación.

Añadir primitivas clásicas

Cuando queremos tener condicionantes o sentencias de programación como un «if else», «while»… tenemos que insertarlos en el método forward de nuestro modelo. Por ejemplo, a continuación vemos el método forward de la clase MultiheadAttention que hereda de la clase module. Podemos ver cómo implementa varios condicionales «if else».

def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None, need_weights: bool = True, attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:

 
if self.batch_first:
     query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

if not self._qkv_same_embed_dim: 
     attn_output, attn_output_weights = F.multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight)
else:
     attn_output, attn_output_weights = F.multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask)
if self.batch_first:
     return attn_output.transpose(1, 0), attn_output_weights
else:
     return attn_output, attn_output_weights

Crear un nuevo módulo

En este caso, definimos un nuevo módulo que herede de la clase nn.Module. En el método de inicialización init instanciamos todos los objetos que vayamos a usar y en el método forward definimos el grafo de operaciones de PyTorch que va a ejecutar el módulo. A continuación vemos la estructura del código.

class MyLayer(nn.Module):
    """ Custom layer """
    def __init__(self, size_in, size_out):
        super().__init__()
        # initialize the objects

    def forward(self, x):
        #Define the sequence of PyTorch operations, 
        #including differentiable and non-differentiable operations
        return result

Crear una nueva función

Este caso, descrito aquí, es el más complicado y por ello el último recurso. Es necesario crear una subclase de Function y definir el método forward() y backward(), llamar a los métodos apropiados en el argumento ctx, declarar si la función soporta doble backward y validar que los gradientes son correctos usando gradcheck.

En el forward, recibimos un tensor que contiene la entrada y devolvemos un tensor que contiene la salida. ctx es un objeto de contexto que se puede utilizar para almacenar información para el backward. Puede almacenar en caché objetos arbitrarios para usarlos en el backward utilizando el método ctx.save_for_backward.

En el backward recibimos un tensor que contiene el gradiente de la pérdida con respecto a la salida, y necesitamos calcular el gradiente de la pérdida con respecto a la entrada.

Deja un comentario

Tu dirección de correo electrónico no será publicada. Los campos obligatorios están marcados con *

Orgullosamente ofrecido por WordPress | Tema: Baskerville 2 por Anders Noren.

Subir ↑

A %d blogueros les gusta esto: