Alle prese con PyTorch - Parte 3: Implementare nuovi moduli


PyTorch è un framework di deep learning, sviluppato principalmente dal Facebook AI Research (FAIR) group, che ha guadagnato una enorme popolarità fra gli sviluppatori grazie alla combinazione di semplicità ed efficienza. Questi tutorial sono dedicati ad esplorare la libreria, partendo dai concetti più semplici fino alla definizione di modelli estremamente sofisticati.

In questa terza parte vediamo come implementare nuovi moduli all'interno della libreria con un esempio pratico: Swish, una funzione di attivazione con ottime performance introdotta da Google l'anno scorso.

Questi tutorial sono anche disponibili (parzialmente) in lingua inglese: Fun With PyTorch.

Contenuto di questo tutorial

Finora abbiamo visto come usare gli strumenti ed i modelli già pronti di PyTorch per creare reti neurali ed ottimizzarle sui nostri problemi. Il deep learning, però, si evolve di giorno in giorno: in continuazione vengono proposte nuove idee o varianti di idee note. Per essere in grado di sperimentarle, è utile capire anche come implementare nuovi moduli all'interno di PyTorch.

Per rimanere sul pratico, useremo come caso d'uso Swish, una funzione di attivazione proposta in un articolo ad Ottobre 2017 in alternativa alle più classiche ReLU o tangenti iperboliche, che sembra in grado di ottenere miglioramenti significativi in numerosi problemi.

Curiosità: Swish non è stata progettata 'a mano', ma è il risultato di una ricerca automatica nello spazio di tutte le possibili funzioni di attivazione, tramite un meccanismo di reinforcement learning, per cercare di trovare la 'migliore' funzione di attivazione esistente (più dettagli, ovviamente, nell'articolo originario: Ramachandran et al., Searching for Activation Functions, arXiv:1710.05941, 2017).

Matematicamente, Swish è definita nel seguente modo:

$$ h(s) = s \cdot \sigma(\beta s) \,,$$

dove $s$ è l'input della funzione di attivazione, $\sigma(\cdot)$ è la classica funzione sigmoide, e $\beta$ è un valore costante o appreso a livello del singolo neurone.

Avremo modo di commentare ampiamente le caratteristiche di Swish durante la nostra implementazione. Per rendere tutto il più chiaro possibile, procederemo in tre fasi (più una sezione 'bonus'), costruendo versioni via via più sofisticate della funzione stessa.

Parte 1 di 3: Swish senza parametri

Partiamo dal caso più semplice, con $\beta$ fisso ad 1, che nell'articolo viene chiamata Swish-1 (ed era stata proposta in precedenza sotto il nome di sigmoid-weighted linear unit).

Ricordate dalla seconda parte del tutorial che tutti i componenti di una rete neurale, in PyTorch, sono implementati come estensioni di torch.nn.Module, per permettere di riutilizzarli all'interno di modelli sempre più complessi. In questo senso, implementare Swish-1 è molto simile a quanto visto in precedenza:

import torch
class Swish1(torch.nn.Module):
    def forward(self, input):
        return input * torch.sigmoid(input)

Non c'è bisogno di inizializzare nulla in questo caso, non avendo nessun parametro da selezionare. Questa versione è già funzionale: ad esempio, possiamo reimplementare il modello che avevamo usato su Iris, sostituendo la ReLU con la nostra nuova funzione:

net_sequential = nn.Sequential(
        nn.Linear(4, 10),
        Swish1(),
        nn.Linear(10, 3)
)

Possiamo anche divertirci a graficare la funzione:

# Valori sull'asse x
x = np.linspace(-5.0, 5.0, 1000).reshape(-1, 1)

# Calcola Swish-1 su tutti i valori
swish1 = Swish1()
y = swish1(torch.from_numpy(x))

# Grafica il risultato
plt.plot(x, y.numpy())
Swish-1

Notiamo che questa versione della funzione è molto simile alla ReLU, con la differenza importante che Swish-1 ha un comportamento non-monotono: per attivazioni negative diminuisce prima di risalire verso 0. Il motivo del perché questo migliori le performance non è chiarissimo nemmeno agli autori dell'articolo!

Possiamo anche usare gli strumenti di differenziazione automatica per graficare la derivata della funzione:

x = torch.linspace(-5.0, 5.0, 1000, requires_grad=True)

# Ci sono modi più efficienti! :-)
g = [torch.autograd.grad(swish1(xi), xi) for xi in x]

plt.plot(x.detach().numpy(), g)
Swish-1 (Gradiente)

Fatto tutto questo, passiamo a qualcosa di più interessante.

Parte 2 di 3: Swish con parametro costante

Per la seconda parte dell'implementazione, introduciamo il parametro $\beta$ ma lo lasciamo a scelta dell'utente e non adattabile. Saremmo tentati (seguendo quanto visto prima) di implementare il tutto così:

class ConstantBetaSwish(nn.Module):
    # QUESTA IMPLEMENTAZIONE E' ERRATA

    def __init__(self, beta=2.0):
        super(ConstantBetaSwish, self).__init__()
        self.beta = torch.tensor(beta)

    def forward(self, input):
        return input * torch.sigmoid(input * self.beta)

L'unica differenza è la costante $\beta$, passata come parametro di inizializzazione. Anche se tutto sembra corretto, questa implementazione ha un bug, che spunta fuori se proviamo ad ottenere lo stato di questo modulo (ad esempio per fare checkpointing:

swish2.state_dict() # Vuoto!

Per fare in modo che $\beta$ venga considerato parte integrante dello stato del modulo, è necessario 'registrarlo' in fase di inizializzazione con un metodo apposito, register_buffer:

class ConstantBetaSwish(nn.Module):

    def __init__(self, beta=2.0):
        super(ConstantBetaSwish, self).__init__()
        self.register_buffer('beta', torch.tensor(beta, dtype=torch.float32))

    def forward(self, input):
        return input * torch.sigmoid(input * Variable(self.beta))

    def extra_repr(self):
        return 'beta={}'.format(self.beta)

Ne abbiamo anche approfittato per aggiungere un nuovo metodo, extra_repr, che permette di stampare a schermo informazioni utili sul modulo, ad esempio:

net = nn.Sequential(
        nn.Linear(4, 5),
        ConstantBetaSwish(),
        nn.Linear(5, 2)
)
print(net)
# Sequential(
#  (0): Linear(in_features=4, out_features=5, bias=True)
#  (1): ConstantBetaSwish(beta=2.0)
#  (2): Linear(in_features=5, out_features=2, bias=True)
# )

Possiamo usare questa versione anche per vedere come si comporta Swish al variare di $\beta$:

Swish con beta costante

Al variare di $\beta$ la funzione assume numerose forme interessanti, passando dall'essere una funzione quasi lineare con $\beta$ molto piccolo, fino alla classica ReLU per $\beta$ molto alto. Sarebbe interessante poter usare tutte queste varianti all'interno dei nostri modelli - o ancora meglio, lasciare che sia l'ottimizzazione stessa a decidere quale usare.

Siete interessati? Proseguiamo!

Parte 3 di 3: Swish con parametro adattabile

Far sì che l'ottimizzazione selezioni un $\beta$ ottimale per ciascun neurone è abbastanza facile: basta dire a PyTorch che quei valori sono parametri del modello stesso, e verranno inclusi automaticamente nella fase di ottimizzazione. Niente di più facile:

class BetaSwish(nn.Module):
    def __init__(self, num_parameters=1):
        super(BetaSwish, self).__init__()

        self.num_parameters = num_parameters
        self.beta = torch.nn.Parameter(torch.ones(1, num_parameters))

    def forward(self, input):
        return input * torch.sigmoid(input * self.beta)

Qualche commento sul codice:

  1. A differenza di prima, dobbiamo specificare quanti neuroni compongono questo strato (num_parameters): questo perché dobbiamo inizializzare un parametro per ogni neurone.
  2. I parametri sono inseriti in un oggetto torch.nn.Parameter. Se ricordate la spiegazione nel tutorial precedente, Parameter è un wrapper di un tensore che identifica quali tensori in un modello devono essere allenati.
  3. Il significato di input * self.beta è leggermente diverso da prima: grazie al broadcasting, stiamo ora moltiplicando ogni colonna di input per un $\beta$ diverso.

Vediamo un esempio di modello costruito con la nuova funzione:

net = nn.Sequential(
        nn.Linear(4, 10),
        BetaSwish(10),
        nn.Linear(10, 3)
)

Se lo eseguiamo sull'esempio dello scorso tutorial (Iris), l'errore scende rapidamente a zero anche in questo caso:

Evoluzione funzione costo (Iris)

Ma la cosa interessante è vedere i valori risultanti di $\beta$:

net[1].beta.detach().numpy()
# array([[0.44259977, 0.9548798 , 0.19685858, 1.0720267 , 3.1051192 ,
#        0.09515426, 1.9494272 , 1.4172938 , 1.4043328 , 0.06701402]],
#      dtype=float32)

L'archittetura ottimale per questo problema, nonostante la sua semplicità, richiede un misto di funzioni di attivazione nello strato nascosto!

Bonus: implementare gradienti personalizzati

Fino a questo punto abbiamo usato solo combinazioni di funzioni predefinite di PyTorch. Nonostante questo copra buona parte di quanto è necessario in pratica, a volte siamo costretti (per diverse ragioni) ad usare funzioni esterne o definite da noi. Purtroppo, questo "rompe" il meccanismo di differenziazione automatica di PyTorch, che non può tracciare quello che avviene all'interno di queste funzioni: in questo caso, è necessario definire noi il gradiente del nuovo modulo.

Questo ci porta ad un livello più basso dei meccanismi di PyTorch, le Function. Le Function sono gli atomi indivisibili di PyTorch, che definiscono le operazioni elementari che è possibile eseguire (es., la somma di due tensori) ed i loro rispettivi gradienti. Ogni grafo che abbiamo definito finora è costituito, alla sua base, solo di tensori e funzioni.

Come esempio, supponiamo di voler reimplementare Swish-1, questa volta usando però la sigmoide definita nella libreria di scipy:

from scipy.special import expit

Vediamo l'implementazione in questo caso:

class SwishFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
        # Calcola la sigmoide (uscendo da autograd)
        input_sigmoid = torch.from_numpy(expit(input.detach().numpy()))
        # Salva tutto quello che serve per la back-propagation
        ctx.save_for_backward(input, input_sigmoid)
        return input * input_sigmoid

    @staticmethod
    def backward(ctx, grad_output):
        # Recupera i tensori salvati
        input,input_sigmoid, = ctx.saved_tensors
        # Calcola il gradiente
        grad_af = input_sigmoid + input * input_sigmoid * (1 - input_sigmoid)
        return grad_output * grad_af

Commentiamo le varie istruzioni del codice:

  1. Prima di tutto, com'è ovvio, ereditiamo da Function e non da Module. Una funzione richiede la definizione della sua forward pass (forward), e la rispettiva backward pass per la back-propagation (backward()). Si noti come entrambi siano metodi statici e non più metodi dinamici dell'oggetto.

  2. Per calcolare $\sigma(s)$, questa volta eseguiamo torch.from_numpy(expit(input.detach().numpy())): questo richiede di staccarci dal meccanismo di auto-differenziazione (con detach()) per invocare una funzione su array di NumPy.

  3. Poiché forward e backward sono metodi statici, è necessario un meccanismo per salvare tutti i valori utili per la back-propagation: questo è fornito da ctx.save_for_backward nella fase forward, e ctx.saved_tensors (per recuperarli) nella fase backward. In questo caso salviamo l'input passato alla funzione (necessario), ed anche il valore di $\sigma(s)$, che risparmia un po' di conti nella fase backward.

  4. La penultima riga di backward calcola il gradiente di Swish, che si ottiene facilmente derivando per parti: $\frac{d \text{Swish-1}(s)}{d s} = \sigma(s) + s \cdot \sigma'(s)$.

  5. Come detto prima, i gradienti vengono sempre calcolati all'interno di un meccanismo di back-propagation: grad_output è un tensore che mantiene i gradienti calcolati fino a quel punto da autograd. Poiché le funzioni di attivazione operano sui singoli elementi, il gradiente complessivo è dato dal gradiente di Swish-1 moltiplicato per grad_output.

  6. La fase backward in questo caso ritorna un solo tensore in output, ovvero il gradiente rispetto a input: nel caso la funzione avesse più input, sarebbe necessario ritornare il gradiente rispetto a ciascuno di essi.

Possiamo anche aggiungere dei controlli per verificare che sia effettivamente necessario calcolare i gradienti: si veda http://pytorch.org/docs/master/notes/extending.html.

Per verificare che i gradienti siano implementati correttamente, PyTorch mette a disposizione un test alle differenze finite, che verifica numericamente che i valori risultanti siano corretti:

from torch.autograd import gradcheck
input = (torch.randn(20, 20, requires_grad=True),)
test = gradcheck(SwishFunction.apply, input, eps=1e-2, atol=1e-2)
print(test)
# True

A questo punto non rimane altro che ridefinire il nostro modulo, questa volta sfruttando la nostra nuova funzione:

swish = SwishFunction.apply

class Swish(nn.Module):
    def forward(self, input):
        return swish(input)

Ed anche per questa volta è tutto! Nella prossima parte del tutorial, è tempo di passare alle reti convolutive ed agli strumenti per lavorare sulle immagini: Alle prese con PyTorch - Parte 4: Torchvision e Reti Convolutive.


Se questo articolo ti è piaciuto e vuoi tenerti aggiornato sulle nostre attività, ricordati che l'iscrizione all'Italian Association for Machine Learning è gratuita! Puoi seguirci anche su Facebook e su LinkedIn.

Previous Post Next Post