Como hemos comentado en otros posts, en los modelos de machine learning el objetivo es encontrar los parámetros del modelo que minimizan una función de coste.
En redes neuronales se usa el algoritmo de descenso por gradiente, que va realizando sobre los parámetros iteraciones proporcionales al valor negativo del gradiente en el punto actual. El algoritmo usa backpropagation para calcular el valor del gradiente.
Para los parámetros de la capa l se repetiría la siguiente iteración hasta que los parámetros converjan (o hasta un máximo de iteraciones prefijado):
¿Pero qué ocurre cuando la red tiene muchas capas?
En esta red neuronal simple con n capas ocultas y con una neurona por capa, la salida de cada neurona se calcula multiplicando la salida de la neurona previa por el parámetro Wn y aplicando la función de activación f(). La función de coste J al final de la red devuelve el error del sistema y se utiliza para modificar el resto de parámetros de la red a través del descenso por gradiente.
La derivada respecto al parámetro de la primera capa usando la regla de la cadena queda:
Si calculamos el valor de un ejemplo de las derivadas entre capas:
Donde
Si la función de activación elegida f(z) es la función sigmoide, una de las más comunes, debido a que su derivada siempre está acotada entre 0 y 0.25, cuando tenemos una red con muchas capas el valor de gradiente cada vez es más cercano a 0 ya que estamos multiplicando muchas veces un valor pequeño. Debido a este problema, las primeras capas de una red neuronal son las más lentas y difíciles de entrenar ya que el valor del gradiente que se usa para actualizarlas en cada iteración del entrenamiento es muy pequeño. Y esto causa otro problema, que si las primeras capas no están bien entrenadas, el problema se arrastra a las capas posteriores.
El problema es similar en las redes neuronales recurrentes (RNN). Las RNN modelan datos en los que es importante la estructura temporal, como frases con palabras, y también se entrenan usando backpropagation. Cada intervalo de tiempo sería como una capa y por lo tanto una RNN sería equivalente a una red con tantas capas como intervalos de tiempo y sufrirían el mismo problema de desvanecimiento del gradiente.
En la imagen se puede ver como la sensibilidad a los valores de la entrada en una RNN decae con el tiempo. De esta forma es muy difícil que la red recuerde dependencias temporales largas.
¿Qué soluciones hay al problema del desvanecimiento del gradiente?
Como hemos visto, la derivada de la función sigmoide es menor o igual que 0.25, pero si cogemos como función de activación la función ReLU, cuya derivada es 1 por encima de 0, podríamos obtener mejores soluciones.
Para el caso de las RNN, la solución pasa por usar LSTMs (Long Short-Term Memory Networks), donde cada nodo o neurona es una célula de memoria. De esta forma la red es capaz de retener información de entradas anteriores en el tiempo y tener en cuenta dependencias temporales largas.