top of page
Search
  • Writer's picturesiria sadeddin

Loss Focal: Una solución para el desbalance de datos

Hola de nuevo! 🤩


Mientras estudiaba métodos de Computer Vision para tratar el desbalance de datos me encontré con este método para el cálculo del Loss que me pareció interesante y quiero compartir con ustedes hoy. Este método se llama Focal Loss, y se basa en la reconstrucción del Loss de Crossentropy Binario de manera que los ejemplos "fáciles" de clasificar contribuyen menos al Loss, así el modelo puede "concentrarse" en aprender los ejemplos mas "difíciles" de clasificar. Esto es especialmente útil cuando tenemos un conjunto de datos muy desbalanceado.



Imagen tomada de [1]


El Loss usual para la Crossentropy Binario es el siguiente:

Donde 'p' es la probabilidad de pertenecer a la clase (0, 1) y 'y' el valor verdadero. Por simplicidad definiremos:



Clases pesadas

 

Un método comúnmente usado para tratar el desbalance de clases añadir pesos a las contribuciones de cada clase en la función de Loss. Por ejemplo, un set de datos donde la relación entre clases binarias [1,0] es de 1:10 tendrá pesos w_1=0.9 y w_0=0.1. De esta manera la contribución total de cada una de las clases al Loss será proporcional, sin importar la cantidad de ejemplos que tenga cada clase en el conjunto de datos.


Definimos la función de Loss Pesada como sigue:


Donde α_t es el peso para cada clase.


En python podemos escribir el Loss pesado de la siguiente manera:


import tensorflow as tf
from keras import backend as K 
     
def weighted_binary_crossentropy(weight1=freq_neg, weight0=freq_pos):
    def loss(y_true,y_pred):
        y_true = tf.cast(y_true, tf.float32)
        y_pred = tf.cast(y_pred, tf.float32)
        logloss = - (y_true * K.log(y_pred + K.epsilon()) * weight1 +
        (1 - y_true) * K.log(1 - y_pred + K.epsilon()) * weight0 )
        return K.mean( logloss, axis=-1)
     return loss

Loss Focal

 

A pesar de que los pesos solucionan el problema del desbalance de los datos, esto no soluciona el problema de los ejemplos fáciles o difíciles de clasificar. Cuando un ejemplo es difícil de clasificar, este contribuye mas al Loss, mientras que los ejemplos fáciles de clasificar contribuyen menos al Loss, lo que queremos es minimizar las contribuciones de los ejemplos fáciles de clasificar, para que durante el entrenamiento, el modelo ponga mas atención a los ejemplos difíciles. Esto lo lograremos agregando un factor de modulación dinámico a la ecuación de Loss:



Vemos que cuando ɣ=0 retornamos a la función de Loss original. Cuando un ejemplo es fácil de clasificar p_t ->1 y el factor de modulación tiende a cero, mientras que cuando el ejemplo es muy difícil de clasificar p_t->0 y el Loss tiene factor de modulación ->1. Se ha encontrado que ɣ=2 funciona bien en los experimentos [1].


Supongamos que elegimos ɣ=2, si p_t=0.9 el Loss correspondiente será 100 veces menor que en el Loss original, mientras que si p_t=0.968 el Loss será 1000 veces menor. Para los casos donde la clasificación es difícil, p_t=0.5 por ejemplo, el Loss se verá reducido solo en un factor de 4. En el gráfico al principio de este post podemos ver como se comporta el Loss en función de la probabilidad de pertenecer a la clase verdadera.


Podemos usar un enfoque mixto al incluir el factor de modulación y los pesos de las clases en la función:





Un ejemplo del código en python para esta función de Loss se presenta a continuación:


def binary_focal_loss(gamma=2., alpha=pos_freq):
 """
    Binary form of focal loss.
      FL(p_t) = -alpha * (1 - p_t)**gamma * log(p_t)
      where p = sigmoid(x), p_t = p or 1 - p depending on if the label is 1 or 0, respectively.
    References:
        https://arxiv.org/pdf/1708.02002.pdf
    Usage:
     model.compile(loss=[binary_focal_loss(alpha=.25, gamma=2)], metrics=["accuracy"], optimizer=adam)
    """
 def binary_focal_loss_fixed(y_truey_pred):
 """
        :param y_true: A tensor of the same shape as `y_pred`
        :param y_pred:  A tensor resulting from a sigmoid
        :return: Output tensor.
        """
        pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
        pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))

        epsilon = K.epsilon()
 # clip to prevent NaN's and Inf's
        pt_1 = K.clip(pt_1, epsilon, 1. - epsilon)
        pt_0 = K.clip(pt_0, epsilon, 1. - epsilon)

 return -K.sum(alpha * K.pow(1. - pt_1, gamma) * K.log(pt_1)) -K.sum((1 - alpha) * K.pow(pt_0, gamma) * K.log(1. - pt_0))

 return binary_focal_loss_fixed
 
 

Les invito que intenten usar estos métodos en sus entrenamiento, espero que sea de ayuda.

Hasta pronto! 😁





Referencias

 

Focal Loss for Dense Object Detection, Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He, Piotr Dollar. Facebook AI Research (FAIR). (7 Feb 2018) https://arxiv.org/pdf/1708.02002.pdf


287 views0 comments
Post: Blog2_Post
bottom of page