Dal k-NN al Transformer in pochi passi (quasi)


Uno degli articoli scientifici più influenti dell'ultima decade è sicuramente Attention is All You Need.1 Come da titolo, l'obiettivo dell'articolo era semplice: mostrare come una componente delle reti neurali fino a quel momento di nicchia (neural attention, o semplicemente attenzione in questo post) bastava da sola a costruire architetture neurali estremamente sofisticate. La famiglia di modelli così ottenuti, i Transformer, sono oggi fondamentali in numerosi campi, dal natural language processing1 alle graph neural networks2 ed alla computer vision.3

Il meccanismo di attenzione è però meno misterioso di quanto possa sembrare a prima vista, e lo possiamo ritrovare (seppur in una forma rudimentale) in uno degli algoritmi più usati nel machine learning: il $k$-nearest neighbours (k-NN)! In questo post, vediamo quindi come costruire un meccanismo realistico di attenzione partendo proprio dal k-NN.

Un primo meccanismo d'attenzione

Il k-NN è tanto famoso quanto semplice: preso un training set $(x_i, y_i)$ di coppie input/output, ed un elemento $x$ su cui effettuare una predizione, scegliamo gli indici $\mathcal{N}(x)$ dei $k$ input più vicini (simili) a $x$, e prediciamo la media dei loro rispettivi output:

$$ f(x) = \frac{1}{k} \sum_{j \in \mathcal{N}(x)} y_j$$

Varianti per la classificazione considerano altri tipi di aggregazione, ma non sono di interesse qui.

Il k-NN classico assegna ad ogni elemento un peso uniforme, che risulta problematico nel caso di input particolarmente lontani. Una semplice variante (detta weighted k-NN) consiste nell'assegnare un peso variabile $s(x, x_j)$ che rappresenta la similarità tra $x$ ed $x_j$:

$$ f(x) = \sum_{j \in \mathcal{N}(x)} \color{red}{s(x, x_j)} \cdot y_j$$

A differenza di prima, l'output non è più correttamente normalizzato, in quanto la somma dei pesi non è necessariamente 1. Per rimediare, possiamo includere un termine esplicito di normalizzazione (equivalente alla moltiplicazione per $\frac{1}{k}$ di prima):

$$ f(x) = \color{red}{\frac{1}{\sum_{j \in \mathcal{N}(x)} s(x, x_j)}} \sum_{j \in \mathcal{N}(x)} s(x, x_j) \cdot y_j$$

Una scelta piuttosto comune per calcolare la similarità è l'inverso della distanza Euclidea:

$$ s(x, x_j) = \frac{1}{\lVert x - x_j \rVert}$$

Una alternativa comune nelle reti neurali è il prodotto scalare:

$$ s(x, x_j) = x^Tx_j$$

Una seconda alternativa comune è l'uso di una funzione softmax per ottenere la normalizzazione. Nell'equazione sopra, questo è equivalente ad usare $\exp(s(x, x_j))$ al posto di $s(x, x_j)$ (con una qualsiasi misura di similarità):

$$ f(x) = \frac{1}{\sum_{j \in \mathcal{N}(x)} \color{red}{\exp(}s(x, x_j)\color{red}{)}} \sum_{j \in \mathcal{N}(x)} \color{red}{\exp(}s(x, x_j)\color{red}{)} \cdot y_j$$

Dal punto di vista delle reti neurali, quello che abbiamo costruito non è altro che un rudimentale meccanismo di attenzione! Nella sua generalità, ed usando la terminologia delle reti neurali, tale meccanismo ci permette di aggregare un insieme di valori (gli output $y_j$ del training set) sulla base di una serie di confronti tra le chiavi (gli input $x_j$ del training set) ed una query di riferimento (il nuovo input $x$ da predire), come mostrato sotto.

Figura 1: Schema semplificato del weighted k-NN visto come meccanismo di attenzione.

Introduciamo qualche parametro

Una differenza notevole tra un k-NN, seppur pesato, ed una rete neurale, è la mancanza di parametri allenabili: una volta fissato il training set e la misura di similarità, le predizioni dell'algoritmo sono determinate senza nessuna procedura di allenamento.

Il modo più semplice di introdurre dei parametri è di considerare una funzione di similarità allenabile. Nei meccanismi di attenzione più comuni (dot-product attention) questo si ottiene semplicemente proiettando la query e le chiavi con una matrice allenabile, e calcolando la similarità tramite il loro prodotto scalare:

$$ s(x, x_j) = {\underbrace{\left( Wx \right)}_{\text{query } q}}^T \underbrace{\left( Wx_j \right)}_{\text{chiave } k_j}$$

In pratica, è comune aggiungere anche un termine $\frac{1}{\sqrt{K}}$, dove $K$ è la dimensione di $Wx$, per normalizzare la similarità.

Come apprendere la matrice $W$? La difficoltà di applicare una discesa al gradiente in questo caso deriva soprattutto dal dover scegliere, ad ogni istante, i $k$ elementi più simili rispetto a $s(\cdot, \cdot)$, una operazione di cui non è possibile calcolare il gradiente. Nonostante sia possibile approssimarla,4 è ancora più facile ignorare del tutto la scelta dei vicini, e calcolare la media pesata rispetto a tutto il dataset:

$$ f(x) = \frac{1}{\color{red}{\sum_{j}} s(x, x_j)} \color{red}{\sum_{j}} s(x, x_j) \cdot y_j$$

Possiamo ora allenare la matrice minimizzando un qualche tipo di errore, come ad esempio l'errore quadratico medio:

$$ L(W) = \sum_i (y_i - f(x_i))^2$$

Come vedremo nella prossima sezione, l'algoritmo così ottenuto comincia ad essere molto simile ad un meccanismo realistico di attenzione. Per completezza, va menzionato che nel caso del k-NN, l'idea di apprendere una funzione di distanza ha dato vita ad una ampia letteratura.5

Un'altra nota interessante: scegliendo una funzione di similarità a kernel $s(x, x_j) = \phi(x)^T\phi(x_j)$ ricadiamo invece nel mondo dei kernel methods (tra cui, ad esempio, le support vector machine). Una delle linee di ricerca dell'ultimo anno cerca proprio di combinare questi due mondi per migliorare le prestazioni dei Transformer.6

Un k-NN con rappresentazioni intermedie

Cosa manca al nostro algoritmo? Finora, abbiamo usato dei meccanismi allenabili per imparare a pesare correttamente le etichette di input 'simili'. In questo senso, il nostro algoritmo è più simile ad un metodo di label propagation (per chi avesse familiarità con i grafi) che ad una rete neurale. Intuitivamente, questo può funzionare solo per input semplici.

Da un algoritmo neurale, invece, ci aspettiamo la capacità di apprendere una serie di rappresentazioni intermedie sempre più astratte (gli strati nascosti della rete) prima di giungere alla predizione finale.

Curiosamente, è abbastanza facile modificare il nostro algoritmo per gestire quest'ultima idea. L'idea di fondo è di imparare ad aggregare gli input $x_j$ invece degli output $y_j$, ottenendo quindi una rappresentazione intermedia come somma pesata degli input di partenza. In questo caso, è più comune costruire le nostre query, chiavi, e valori usando tre diverse matrici di proiezione $W^{(q)}$, $W^{(k)}$, e $W^{(v)}$ (mentre prima ne avevamo usata una sola per chiavi e query):

$$ \color{red}{q} = W^{(q)}x \,, \,\,\, \color{blue}{k_j} = W^{(k)} x_j \,, \,\,\, \color{green}{v_j} = W^{(v)} x_j$$

Usiamo quindi questi nuovi valori nella funzione di attenzione di prima:

$$ \text{Att}(x, \{x_j\}) = \frac{1}{\sum_{j} s(\color{red}{q}, \color{blue}{k_j})} \sum_{j} s(\color{red}{q}, \color{blue}{k_j}) \cdot \color{green}{v_j}$$

Il risultato è simile: usiamo i confronti tra la query $\color{red}{q}$ e le chiavi $\color{blue}{k_j}$ per aggregare i valori $\color{green}{v_j}$ a nostra disposizione, con la differenza che in questo caso l'output è una nuova rappresentazione di $x$ e non più una predizione. Aggiungendo una non-linearità e ripetendo questo processo più volte, otteniamo una rete neurale interamente basata su un meccanismo di attenzione! Poiché usiamo gli stessi dati di partenza per generare query, chiavi, e valori, si parla di self-attention.

Vettorizzare il processo

Un aspetto interessante del meccanismo di attenzione è di essere molto semplice da vettorizzare. Supponiamo infatti di voler ottenere le nuove rappresentazioni per l'intero training set, che possiamo scrivere in forma compatta come una unica matrice $X$ in cui l'$i$-esima riga corrisponde a $x_i$. Possiamo calcolare le nuove query, valori, e chiavi semplicemente come:

$$ \color{red}{Q} = XW^{(q)} \,, \,\,\, \color{blue}{K} = XW^{(k)} \,, \,\,\, \color{green}{V} = XW^{(v)} \,,$$

dove, a differenza di prima, abbiamo ora varie query (le singole righe di $Q$), ciascuna corrispondente ad un input. Possiamo quindi riscrivere il meccanismo di attenzione sull'intera matrice $X$ come:

$$ \text{Att}(X) = \text{softmax}(\color{red}{Q}\color{blue}{K}^T) \color{green}{V} \,,$$

dove, per coerenza con quanto fatto di solito nelle reti neurali, abbiamo esplicitato l'uso della softmax. Questa formulazione è così comune da essere ormai presente nella maggior parte dei framework di deep learning (es., in PyTorch), e l'abbiamo riassunta nella figura sotto.

Figura 2: schema del meccanismo di self-attention.

Possiamo adesso cominciare a costruire reti a più strati aggiungendo una non-linearità $\phi$ (es., una ReLU) tra due componenti di self-attention:

$$ H = \text{Att}(\phi(\text{Att}(X))) \,,$$

ovviamente assumendo che le due componenti abbiano parametri allenabili indipendenti.

Graph attention network e transformer

In questo post, partendo dal k-NN, abbiamo ottenuto un algoritmo che aggrega tutti gli input del training set per ottenere una predizione tramite meccanismi di self-attention. Per gli appassionati delle graph neural network, quello che abbiamo costruito è in effetti un prototipo molto semplice di una graph attention network.2 Nel deep learning è comune usare il meccanismo di attenzione per aggregare gli elementi di una sequenza come, ad esempio, i token di una frase nel natural language processing. È proprio in questo contesto che il meccanismo di attenzione ricopre un ruolo fondamentale in quanto è alla base dei Transformer di cui parleremo in dettaglio in un altro post.

Vale la pena fare qualche osservazione conclusiva su quanto ottenuto:

  1. Formalmente, quello che abbiamo sviluppato è un meccanismo single-head. Una variante multi-head si ottiene replicando più volte il processo di attenzione in parallelo, e combinando poi i risultati (ci ritorneremo).
  2. Il dot-product non è l'unico meccanismo di attenzione possibile. paperswithcode ha una bella overview di alcuni dei meccanismi più comuni proposti in letteratura.

Nei prossimi articoli, parleremo più nel dettaglio dei Transformer, e dell'implementazione di queste tecniche.

Note conclusive

Questo articolo è ispirato ad uno dei nuovi capitoli di Dive into Deep Learning, che considera un setup simile con uno stimatore di Nadaraya-Watson per un problema di regressione. L'autore ringrazia inoltre tutti quelli che hanno voluto esprimere feedback e commenti sull'articolo.


Se questo articolo ti è piaciuto e vuoi tenerti aggiornato sulle nostre attività, puoi seguirci anche su Facebook, LinkedIn, Twitter, Telegram, e Discord.


  1. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A.N., Kaiser, L. and Polosukhin, I., 2017. Attention is all you need. NeurIPS. 

  2. Veličković, P., Cucurull, G., Casanova, A., Romero, A., Lio, P. and Bengio, Y., 2017. Graph attention networks. arXiv preprint arXiv:1710.10903. 

  3. Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S. and Uszkoreit, J., 2020. An image is worth 16x16 words: Transformers for image recognition at scale. arXiv preprint arXiv:2010.11929. 

  4. Plötz, T. and Roth, S., 2018. Neural nearest neighbors networks. Advances in Neural Information Processing Systems, 31, pp.1087-1098. 

  5. Weinberger, K.Q. and Saul, L.K., 2009. Distance metric learning for large margin nearest neighbor classification. Journal of machine learning research, 10(2). 

  6. Choromanski, K., Likhosherstov, V., Dohan, D., Song, X., Gane, A., Sarlos, T., Hawkins, P., Davis, J., Mohiuddin, A., Kaiser, L. and Belanger, D., 2020. Rethinking attention with performers. arXiv preprint arXiv:2009.14794. 

Previous Post