Siamese Networks: Hard Negative Mining

Hard Negative Mining

Now we will implement the TripletLoss. The loss is composed of two terms: one uses the mean of all the non-duplicates, the other uses the closest negative. Our loss expression is then:

\begin{align}
\mathcal{Loss}_1(A,P,N) &= \max\left(-\cos(A,P) + \text{mean}_{neg} + \alpha, 0\right) \\
\mathcal{Loss}_2(A,P,N) &= \max\left(-\cos(A,P) + \text{closest}_{neg} + \alpha, 0\right) \\
\mathcal{Loss}(A,P,N) &= \text{mean}(\mathcal{Loss}_1 + \mathcal{Loss}_2)
\end{align}
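To make the two terms concrete, here is a quick worked example with made-up numbers for a single anchor/positive pair (not values from the model): suppose \(\cos(A,P) = 0.9\), \(\text{mean}_{neg} = 0.3\), \(\text{closest}_{neg} = 0.7\), and \(\alpha = 0.25\). Then:

\begin{align}
\mathcal{Loss}_1 &= \max(-0.9 + 0.3 + 0.25, 0) = \max(-0.35, 0) = 0 \\
\mathcal{Loss}_2 &= \max(-0.9 + 0.7 + 0.25, 0) = \max(0.05, 0) = 0.05
\end{align}

So this pair contributes \(0 + 0.05 = 0.05\) to the batch mean: only negatives that score within the margin \(\alpha\) of the positive produce a penalty.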

Here is a list of things we have to do:

  • As this will be run inside trax, use fastnp.xyz for any numpy function xyz
  • Use fastnp.dot to calculate the similarity matrix \(v_1v_2^T\) of dimension batch_size x batch_size
  • Take the score of the duplicates on the diagonal with fastnp.diagonal
  • Use the trax functions fastnp.eye and fastnp.maximum for the identity matrix and the maximum (see the short sketch after this list)
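Here is a minimal sketch of those four fastnp calls on a toy batch, just to confirm they behave like their numpy counterparts (the numbers are made up):

from trax.fastmath import numpy as fastnp

toy = fastnp.array([[1.0, 0.0], [0.5, 0.5]])
similarity = fastnp.dot(toy, toy.T)        # (2, 2) matrix of row dot-products
duplicates = fastnp.diagonal(similarity)   # entries on the diagonal
identity = fastnp.eye(2)                   # 2 x 2 identity matrix
floored = fastnp.maximum(similarity - 2.0 * identity, 0.0)  # element-wise maximum against 0.0
print(similarity)
print(duplicates)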

Imports

# python
from functools import partial

# pypi
from trax.fastmath import numpy as fastnp
from trax import layers

import jax
import numpy

Implementation

More Detailed Instructions

We'll describe the algorithm using a detailed example. Below, V1, V2 are the output of the normalization blocks in our model. Here we will use a batch_size of 4 and a d_model of 3. The inputs, Q1, Q2 are arranged so that corresponding inputs are duplicates while non-corresponding entries are not. The outputs will have the same pattern.

This test case arranges the outputs, v1 and v2, to highlight different scenarios. Here, the first two outputs, V1[0] and V2[0], match exactly, so the model is generating the same vector for the Q1[0] and Q2[0] inputs. The second outputs differ (circled in orange): V2[1] is set to match V2[2], simulating a model that is generating very poor results. V1[3] and V2[3] match exactly again, while V1[4] and V2[4] are set to be exactly wrong, 180 degrees from each other (circled in blue).
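The figure itself isn't reproduced here, but a batch in the same spirit might look like the following sketch: rows are L2-normalized, one pair is an exact duplicate, one V2 row duplicates the wrong row, and one pair points in opposite directions. These are hypothetical vectors, not the ones from the figure.

import numpy

def normalize(rows: numpy.ndarray) -> numpy.ndarray:
    """L2-normalize each row."""
    return rows / numpy.linalg.norm(rows, axis=1, keepdims=True)

V1 = normalize(numpy.array([[1.0, 2.0, 3.0],
                            [2.0, 2.0, 1.0],
                            [3.0, 1.0, 2.0],
                            [1.0, 1.0, 1.0]]))
V2 = V1.copy()
V2[1] = V2[2]    # a poorly-modeled pair: V2[1] duplicates a different row
V2[3] = -V2[3]   # an 'exactly wrong' pair: 180 degrees from V1[3]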

Cosine Similarity

The first step is to compute the cosine similarity matrix (score in the code). This is \(V_1 V_2^T\), which is generated with fastnp.dot.

The clever arrangement of inputs creates the data needed for positive and negative examples without having to run all pair-wise combinations. Because Q1[n] is a duplicate of only Q2[n], the other combinations are explicitly created negative examples, or hard negative examples. The matrix multiplication efficiently produces the cosine similarity of all positive/negative combinations, as shown above on the left side of the diagram. 'Positive' refers to the results of duplicate examples and 'negative' to the results of the explicitly created negative examples. The results for our test case are as expected: V1[0] and V2[0] match, producing a '1', while our other 'positive' cases (in green) don't match well, as was arranged. V2[2] was set to match V1[3], producing a poor match at score[2,2] and an undesired 'negative' case of '1', shown in grey.

With the similarity matrix (score) we can begin to implement the loss equations. First, we can extract \(\cos(A,P)\) by utilizing fastnp.diagonal. The goal is to grab all the green entries in the diagram above. This is positive in the code.
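Putting those two steps together on a tiny normalized batch (a sketch with made-up vectors, not the notebook's test case):

import numpy
from trax.fastmath import numpy as fastnp

# two small normalized batches: row 0 is a perfect duplicate, row 1 is not
v1 = numpy.array([[1.0, 0.0], [0.0, 1.0]])
v2 = numpy.array([[1.0, 0.0], [0.6, 0.8]])

score = fastnp.dot(v1, v2.T)        # (batch_size, batch_size) cosine similarities
positive = fastnp.diagonal(score)   # cos(A, P) for each pair: [1.0, 0.8]
print(score)
print(positive)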

Closest Negative

Next, we will create the closest_negative. This is the nonduplicate entry in V2 that is closest (has largest cosine similarity) to an entry in V1. Each row, n, of score represents all comparisons of the results of Q1[n] vs Q2[x] within a batch. A specific example in our testcase is row score[2,:]. It has the cosine similarity of V1[2] and V2[x]. The closest_negative, as was arranged, is V2[2] which has a score of 1. This is the maximum value of the 'negative' entries (blue entries in the diagram).

To implement this, we need to pick the maximum entry in a row of score while ignoring the 'positive'/green entries. To avoid selecting them, we can make them large negative numbers: multiply fastnp.eye(batch_size) by 2.0 and subtract it from score. The result is negative_without_positive. Now we can use fastnp.max, row by row (axis=1), to select the maximum, which is closest_negative.
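A sketch of that masking trick, using a made-up 2 x 2 score matrix:

from trax.fastmath import numpy as fastnp

score = fastnp.array([[1.0, 0.6],
                      [0.0, 0.8]])
batch_size = 2

# push the 'positive' diagonal far down so it can never win the row-wise max
negative_without_positive = score - fastnp.eye(batch_size) * 2.0
closest_negative = fastnp.max(negative_without_positive, axis=1)
print(closest_negative)   # [0.6, 0.0]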

Mean Negative

Next, we'll create mean_negative. As the name suggests, this is the mean of all the 'negative'/blue values in score on a row-by-row basis. We can use fastnp.eye(batch_size) and a constant, this time to create a mask with zeros on the diagonal. Element-wise multiply this with score to get just the 'negative' values. This is negative_zero_on_duplicate in the code. Compute the mean by using fastnp.sum on negative_zero_on_duplicate with axis=1 and dividing by (batch_size - 1). This is mean_negative.
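And the matching sketch for the mean of the off-diagonal entries (same made-up score matrix as above):

from trax.fastmath import numpy as fastnp

score = fastnp.array([[1.0, 0.6],
                      [0.0, 0.8]])
batch_size = 2

# zero out the 'positive' diagonal, then average what's left in each row
negative_zero_on_duplicate = (1.0 - fastnp.eye(batch_size)) * score
mean_negative = fastnp.sum(negative_zero_on_duplicate, axis=1) / (batch_size - 1)
print(mean_negative)   # [0.6, 0.0]

With a batch size of two there is only one negative per row, so mean_negative and closest_negative coincide; with a realistic batch they differ.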

Now, we can compute the loss using the two equations above and fastnp.maximum. This forms triplet_loss1 and triplet_loss2 (since the two terms are summed, it doesn't matter which equation is labeled first).

triplet_loss is the fastnp.mean of the sum of the two individual losses.

def TripletLossFn(v1: numpy.ndarray, v2: numpy.ndarray,
                  margin: float=0.25) -> jax.interpreters.xla.DeviceArray:
    """Custom Loss function.

    Args:
       v1 (numpy.ndarray): Array with dimension (batch_size, model_dimension) associated to Q1.
       v2 (numpy.ndarray): Array with dimension (batch_size, model_dimension) associated to Q2.
       margin (float, optional): Desired margin. Defaults to 0.25.

    Returns:
       jax.interpreters.xla.DeviceArray: Triplet Loss.
    """
    # use fastnp to take the dot product of the two batches (don't forget to transpose the second argument)
    scores = fastnp.dot(v1, v2.T)
    # calculate new batch size
    batch_size = len(scores)
    # use fastnp to grab all positive =diagonal= entries in =scores=
    positive = fastnp.diagonal(scores)  # the positive ones (duplicates)
    # multiply =fastnp.eye(batch_size)= with 2.0 and subtract it out of =scores=
    negative_without_positive = scores - (fastnp.eye(batch_size) * 2.0)
    # take the row by row =max= of =negative_without_positive=. 
    # Hint: negative_without_positive.max(axis = [?])  
    closest_negative = fastnp.max(negative_without_positive, axis=1)
    # subtract =fastnp.eye(batch_size)= out of 1.0 and do element-wise multiplication with =scores=
    negative_zero_on_duplicate = (1.0 - fastnp.eye(batch_size)) * scores
    # use =fastnp.sum= on =negative_zero_on_duplicate= for =axis=1= and divide it by =(batch_size - 1)= 
    mean_negative = fastnp.sum(negative_zero_on_duplicate, axis=1)/(batch_size - 1)
    # compute =fastnp.maximum= among 0.0 and =A=
    # A = subtract =positive= from =margin= and add =closest_negative= 
    triplet_loss1 = fastnp.maximum(0, margin - positive + closest_negative)
    # compute =fastnp.maximum= among 0.0 and =B=
    # B = subtract =positive= from =margin= and add =mean_negative=
    triplet_loss2 = fastnp.maximum(0, (margin - positive) + mean_negative)
    # add the two losses together and take the =fastnp.mean= of it
    triplet_loss = fastnp.mean(triplet_loss1 + triplet_loss2)
    return triplet_loss
v1 = numpy.array([[0.26726124, 0.53452248, 0.80178373],[0.5178918 , 0.57543534, 0.63297887]])
v2 = numpy.array([[ 0.26726124,  0.53452248,  0.80178373],[-0.5178918 , -0.57543534, -0.63297887]])
triplet_loss = TripletLossFn(v2, v1)
print(f"Triplet Loss: {triplet_loss}")

assert triplet_loss == 0.5
Triplet Loss: 0.5
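To see where the 0.5 comes from: with these inputs, scores = fastnp.dot(v2, v1.T) is approximately [[1.0, 0.953], [-0.953, -1.0]], so positive ≈ [1.0, -1.0], and (with only one negative per row) closest_negative and mean_negative are both ≈ [0.953, -0.953]. Each loss term is then ≈ [max(0.25 - 1.0 + 0.953, 0), max(0.25 + 1.0 - 0.953, 0)] = [0.203, 0.297], the per-example sum of the two terms is [0.407, 0.593], and the mean of that is 0.5.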

To make a layer out of a function with no trainable variables, use layers.Fn (tl.Fn in the trax documentation).

from functools import partial
def TripletLoss(margin=0.25):
    triplet_loss_fn = partial(TripletLossFn, margin=margin)
    return layers.Fn('TripletLoss', triplet_loss_fn)

Bundle It Up

Unfortunately, trax does some kind of weirdness where it counts the arguments of the things you use as layers, so class-based callables won't work (the self argument gets counted, giving the function more arguments than trax expects). There might be a way to work around this, but it doesn't appear to be documented, so this has to be done with plain functions. That's not bad, it's just unexpected (and not well documented).

Imports

# python
from functools import partial

# from pypi
from trax.fastmath import numpy as fastmath_numpy
from trax import layers

import attr
import jax
import numpy
import trax

Triplet Loss

def triplet_loss(v1: numpy.ndarray, v2: numpy.ndarray,
                 margin: float=0.25) -> jax.interpreters.xla.DeviceArray:
    """Calculates the triplet loss

    Args:
     v1: normalized batch for question 1
     v2: normalized batch for question 2
     margin: minimum desired gap between duplicate and non-duplicate scores

    Returns:
     triplet loss
    """
    scores = fastmath_numpy.dot(v1, v2.T)
    batch_size = len(scores)
    positive = fastmath_numpy.diagonal(scores)
    negative_without_positive = scores - (fastmath_numpy.eye(batch_size) * 2.0)
    closest_negative = fastmath_numpy.max(negative_without_positive, axis=1)
    negative_zero_on_duplicate = (1.0 - fastmath_numpy.eye(batch_size)) * scores
    mean_negative = fastmath_numpy.sum(negative_zero_on_duplicate, axis=1)/(batch_size - 1)
    triplet_loss1 = fastmath_numpy.maximum(0, margin - positive + closest_negative)
    triplet_loss2 = fastmath_numpy.maximum(0, (margin - positive) + mean_negative)
    return fastmath_numpy.mean(triplet_loss1 + triplet_loss2)

Triplet Loss Layer

Another not-well-documented limitation is that the function you create the layer from isn't allowed to have default values, so if we want to allow the margin to have a default, we have to use partial to set the value before creating the layer…

def triplet_loss_layer(margin: float=0.25) -> layers.Fn:
    """Converts the triplet_loss function to a trax layer"""
    with_margin = partial(triplet_loss, margin=margin)
    return layers.Fn("TripletLoss", with_margin)

Check It Out

from neurotic.nlp.siamese_networks import triplet_loss_layer

layer = triplet_loss_layer()
print(type(layer))
<class 'trax.layers.base.PureLayer'>
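As a quick smoke test, the layer should reproduce the value from the earlier check, assuming a trax layer can be called directly with its inputs passed as a tuple (the vectors are the same test arrays as before):

import numpy
from neurotic.nlp.siamese_networks import triplet_loss_layer

v1 = numpy.array([[0.26726124, 0.53452248, 0.80178373],
                  [0.5178918, 0.57543534, 0.63297887]])
v2 = numpy.array([[0.26726124, 0.53452248, 0.80178373],
                  [-0.5178918, -0.57543534, -0.63297887]])

layer = triplet_loss_layer()
# a Fn layer with two inputs takes them as a tuple
print(layer((v2, v1)))

This should print the same 0.5 as the earlier check of TripletLossFn.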