Deep Dive into the LSTM-CRF Model

Author:Murphy  |  View: 23167  |  Time: 2025-03-23 12:26:38

In the rapidly evolving field of natural language processing, Transformers have emerged as the dominant models, demonstrating remarkable performance across a wide range of sequence modelling tasks, including part-of-speech tagging, named entity recognition, and chunking. Prior to the era of Transformers, Conditional Random Fields (CRFs) were the go-to tool for sequence modelling. The most common variant is the linear-chain CRF, which models a sequence as an undirected chain graph in which each label is connected to its neighbours; CRFs more generally can be defined on arbitrary graphs.

This article will be broken down as follows:

  1. Introduction
  2. Emission and Transition scores
  3. Loss function
  4. Efficient estimation of partition function through Forward Algorithm
  5. Viterbi Algorithm
  6. Full LSTM-CRF code
  7. Drawbacks and Conclusions

Introduction

The implementation of CRFs in this article is based on this excellent tutorial. Please note that it is not the most efficient implementation out there and it lacks batching capability. However, it is relatively simple to read and understand, and because the aim of this tutorial is to get our heads around the internal workings of CRFs, it is perfectly suitable for us.

Emission and Transition scores

In sequence tagging problems, we deal with a sequence of input data elements, such as the words within a sentence, where each element corresponds to a specific label or category. The primary objective is to correctly assign the appropriate label to each individual element. Within the CRF-LSTM model we can identify two key components to do this: emission and transition probabilities. Note we will actually deal with scores in log space instead of probabilities for numerical stability:

  1. Emission scores relate to the likelihood of observing a particular label for a given data element. In the context of Named Entity Recognition, for example, each word in a sequence is assigned one of three labels: Beginning of an entity (B), Intermediate word of an entity (I), or a word outside of any entity (O). Emission probabilities quantify the probability of a specific word being associated with a particular label. This is expressed mathematically as P(y_i | x_i), where y_i denotes the label and x_i represents the input word.
  2. Transition scores, conversely, describe the likelihood of transitioning from one label to another within a sequence. These scores enable the modeling of dependencies between consecutive labels. By capturing these dependencies, transition scores contribute to the coherence and consistency of predicted label sequences. They are denoted as P(y_i | y_(i-1)), with y_i representing the current label and y_(i-1) the previous label in the sequence.

The synergy of these two components results in a robust sequence tagging model.

Figure by author (figure 1)

To assign an emission score to a given word, various feature functions can be employed. They can consider aspects like the word's context, its shape (for example capitalization patterns: if a word starts with a capital letter and it's not at the beginning of a sentence, it's likely to be the beginning of an entity), and morphological features (including prefixes, suffixes, and stems), among others. Defining such features can be labor-intensive and time-consuming. This is why many practitioners opt for a bidirectional LSTM model (I have explained LSTM in this article in case you want to refresh your knowledge) that can compute emission scores based on contextual information for each word without manually defining any features.
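To make the idea of hand-crafted feature functions concrete, here is a minimal sketch. The function name `word_features` and the particular features it returns are illustrative assumptions, not something taken from the tutorial this article follows; a real feature set would be tuned to the task.

```python
def word_features(tokens, i):
    """Hand-crafted features for the token at position i (illustrative only)."""
    word = tokens[i]
    return {
        "lower": word.lower(),
        # capitalised word that is not sentence-initial: a hint for a B label
        "cap_not_first": word[:1].isupper() and i > 0,
        "is_digit": word.isdigit(),
        # simple morphological cues
        "prefix2": word[:2],
        "suffix3": word[-3:],
        # neighbouring words give a little context
        "prev": tokens[i - 1].lower() if i > 0 else "<BOS>",
        "next": tokens[i + 1].lower() if i < len(tokens) - 1 else "<EOS>",
    }


tokens = ["She", "moved", "to", "London", "yesterday"]
features = [word_features(tokens, i) for i in range(len(tokens))]
print(features[3])
```

Every new cue needs another hand-written entry in this dictionary, which is exactly the manual effort that a BiLSTM emission model avoids.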

Subsequently, having obtained the emission scores from the LSTM, we construct a CRF layer to learn the transition scores. The CRF layer leverages the emission scores generated by the LSTM to optimize the assignment of the best label sequence while considering label dependencies.

Loss function

This combined model (LSTM + CRF) can be trained end-to-end (although we are not going to do it in this tutorial) by maximizing the probability of the tag sequence given the inputs, P(y|x), which is the same as minimizing the negative log-likelihood of P(y|x). The notation is as follows:

  • x is the input sequence and y is the label sequence
  • E(y_i|x) is the emission score of label y_i at position i according to the LSTM model
  • T(y_(i-1), y_i) is the CRF's learned transition score from label y_(i-1) to label y_i
  • Z(x) is the partition function, a normalization factor that ensures that the probabilities sum to 1 over all possible label sequences
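With the symbols from the list above, the probability of a label sequence can be sketched as follows. The indexing is an assumption made here to match the code later in the article: n is the sequence length, y_0 is the start tag and y_(n+1) is the stop tag.

```latex
P(y \mid x) = \frac{\exp\left( \sum_{i=1}^{n} E(y_i \mid x) + \sum_{i=1}^{n+1} T(y_{i-1}, y_i) \right)}{Z(x)}
```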

After applying the log it becomes:
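Written out with E, T and Z(x) as defined in the list above, the negative log-likelihood takes the form below. This is a reconstruction under the assumption that n is the sequence length, y_0 is the start tag and y_(n+1) is the stop tag.

```latex
-\log P(y \mid x) = \log Z(x) - \sum_{i=1}^{n} E(y_i \mid x) - \sum_{i=1}^{n+1} T(y_{i-1}, y_i)
```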

The first term is the log of the partition function, the second term quantifies how well the LSTM's emission scores match the true labels while the third accounts for the likelihood of label transitions according to the CRF.

Calculating Z(x) can be computationally demanding due to the necessity of iterating through all possible label sequences within a given input sequence:
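In symbols, and assuming the same indexing as the code later in the article (n positions, y'_0 the start tag, y'_(n+1) the stop tag), the sum runs over every candidate label sequence y':

```latex
Z(x) = \sum_{y'} \exp\left( \sum_{i=1}^{n} E(y'_i \mid x) + \sum_{i=1}^{n+1} T(y'_{i-1}, y'_i) \right)
```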

This process has exponential time complexity: with k labels and a sequence of length n there are k^n label sequences. To address this computational challenge, we are going to use the forward algorithm, which has polynomial time complexity. The forward algorithm relies on dynamic programming (it is the forward pass of the well-known forward-backward algorithm) and computes the partition function exactly without explicitly evaluating every possible label sequence. Its running time scales linearly with the length of the sequence and quadratically with the number of labels.
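A rough way to see the gap is to count the work in each approach. The sketch below is only a back-of-the-envelope count; the helper names `brute_force_paths` and `forward_updates` are made up for illustration.

```python
def brute_force_paths(seq_len, num_tags):
    # every position can take any tag, so there are num_tags ** seq_len paths
    return num_tags ** seq_len


def forward_updates(seq_len, num_tags):
    # each step combines every previous tag with every next tag
    return seq_len * num_tags ** 2


for n in (3, 10, 50):
    print(n, brute_force_paths(n, 3), forward_updates(n, 3))
```

For the tiny 3-word, 3-tag example used in this article the two counts happen to coincide, but the brute-force count explodes as soon as the sequence gets longer.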

Efficient estimation of partition function through Forward Algorithm

To understand how the algorithm works let's start from a simple numerical example. We are now going to define some toy emission and transition scores:

import numpy as np
import torch

emissions = torch.tensor([    [-0.4554, 0.4215, 0.12],
                              [0.4058, -0.5081, 0.21],
                              [0.2058, 0.5081, 0.02]    ])

transitions = torch.tensor([ [0.3,  0.1, 0.1],
                             [0.8,  0.2, 0.2],
                             [0.3,  0.7, 0.7] ])

In the full code in the last section we assume that each sequence starts with tag 0 and ends with tag 1. Normally, they correspond to START_TAG (beginning of sequence tag, aka BOS) and STOP_TAG (end of sequence tag, aka EOS), and we usually set the transition probability from STOP_TAG to any other tag to 0 and the transition probability to START_TAG from any other tag to 0 too (in log space this means a very large negative score). In this toy example we will ignore these rules to not over-complicate things, thus we will allow tags 0 and 1 to also appear in the middle of the sequence.
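As a small illustration of those two rules, here is a sketch that builds a transition table in plain Python lists with the convention `transitions[from_tag][to_tag]`. The constant `IMPOSSIBLE = -1e4` mirrors the value used in the full code; the all-zero starting scores and the choice of 5 tags are assumptions made only for this example.

```python
IMPOSSIBLE = -1e4  # large negative score, i.e. a probability of ~0 after exp
START_TAG, STOP_TAG = 0, 1
num_tags = 5

# start from neutral scores; a real model would learn these values
transitions = [[0.0] * num_tags for _ in range(num_tags)]

for tag in range(num_tags):
    transitions[tag][START_TAG] = IMPOSSIBLE  # never move into START_TAG
    transitions[STOP_TAG][tag] = IMPOSSIBLE   # never move out of STOP_TAG

for row in transitions:
    print(row)
```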

Also, remember that we work in log space for emission and transition scores: exponentiating and normalizing these scores turns them into probabilities. Because we work in log space, we sum the scores rather than multiplying them.
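Summing over paths still requires adding probabilities, which in log space becomes a log-sum-exp. The sketch below is a plain-Python stand-in for `torch.logsumexp`, written with the usual max-subtraction trick; the example scores are arbitrary numbers chosen to show why the trick matters.

```python
import math


def logsumexp(values):
    # subtract the maximum first so that exp() never overflows
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


# exponentiating scores of this size directly would overflow a float
large_scores = [1000.0, 1000.5, 999.0]
stable = logsumexp(large_scores)
print(stable)

# on small scores the stable and naive versions agree
small_scores = [0.1, 0.2, 0.3]
naive = math.log(sum(math.exp(v) for v in small_scores))
print(logsumexp(small_scores), naive)
```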

Let's first see how to compute the partition function in an inefficient way going through all the possible paths:

string_computs = []
dim_seq_len, dim_lbl = emissions.shape
scr = 0
all_s = {}

for s1 in range(dim_lbl):
    for s2 in range(dim_lbl):
        for s3 in range(dim_lbl):
            # assume sequences start with tag 0 and end with tag 1
            s = (0,) + (s1,)  + (s2,)  + (s3,) + (1,)
            string_computs.append(str(s))
            # scores live in log space, so we add them instead of multiplying
            em_sum = 0
            for i in range(len(s)-2):
                em_sum += emissions[i, s[1:-1][i]].detach().numpy()

            scr_temp = 0
            for i, _ in enumerate(range(len(s)-1)): 
                scr_temp += transitions[s[i], s[i+1]].detach().numpy()

            scr += np.exp(em_sum + scr_temp)
            all_s[s] = em_sum + scr_temp

print(np.log(scr))
"""
Output:

5.0607
"""

As we can see, to compute the partition function manually we need to loop through all possible sequences, which in this case are:

print(string_computs)
"""
[(0, 0, 0, 0, 1),
 (0, 0, 0, 1, 1),
 (0, 0, 0, 2, 1),
 (0, 0, 1, 0, 1),
 (0, 0, 1, 1, 1),
 (0, 0, 1, 2, 1),
 (0, 0, 2, 0, 1),
 (0, 0, 2, 1, 1),
 (0, 0, 2, 2, 1),
 (0, 1, 0, 0, 1),
 (0, 1, 0, 1, 1),
 (0, 1, 0, 2, 1),
 (0, 1, 1, 0, 1),
 (0, 1, 1, 1, 1),
 (0, 1, 1, 2, 1),
 (0, 1, 2, 0, 1),
 (0, 1, 2, 1, 1),
 (0, 1, 2, 2, 1),
 (0, 2, 0, 0, 1),
 (0, 2, 0, 1, 1),
 (0, 2, 0, 2, 1),
 (0, 2, 1, 0, 1),
 (0, 2, 1, 1, 1),
 (0, 2, 1, 2, 1),
 (0, 2, 2, 0, 1),
 (0, 2, 2, 1, 1),
 (0, 2, 2, 2, 1)]
"""

This is not very efficient because we compute the same partial sequences multiple times: the prefixes (0,0), (0,1), (0,2) are each computed 9 times, the prefixes (0,0,0), (0,0,1), (0,0,2) are each computed 3 times, and so on. Thus, instead of recomputing the same thing multiple times we can compute each partial result once and store it for later calculations:

START_TAG = 0
STOP_TAG = 1
dim_seq_len, dim_lbl = emissions.shape
# transition from START_TAG into each tag, plus the emission of the first word
init_alphas = transitions[START_TAG] + emissions[0]

# alphas hold the log-space forward scores of the current timestep
alphas = init_alphas
print(alphas)

for emission in emissions[1:]:
    alphas_t = []  # The forward tensors at this timestep
    for next_tag in range(dim_lbl):
        # emission score for the next tag
        emit_score = emission[next_tag].view(1, -1).expand(1, dim_lbl)
        # transition score from any previous tag to the next tag
        trans_score = transitions[:, next_tag].view(1, -1)
        # combine current scores with previous alphas 
        # since alphas are in log space (see logsumexp below),
        # we add them instead of multiplying
        next_tag_var = alphas + trans_score + emit_score
        print(f"Scores {next_tag} - {next_tag_var} |-| {torch.logsumexp(next_tag_var, dim=1)}")

        alphas_t.append(torch.logsumexp(next_tag_var, 1).view(1))

    alphas = torch.cat(alphas_t).view(1, -1)
    print(alphas)

terminal_alphas = alphas + transitions[:, STOP_TAG]
print(terminal_alphas)
alphas = torch.logsumexp(terminal_alphas, 1)
print(alphas)

"""
Outputs:

tensor([-0.1554,  0.5215,  0.2200])
Scores 0 - tensor([[0.5504, 1.7273, 0.9258]]) |-| tensor([2.2908])
Scores 1 - tensor([[-0.5635,  0.2134,  0.4119]]) |-| tensor([1.1990])
Scores 2 - tensor([[0.1546, 0.9315, 1.1300]]) |-| tensor([1.9171])
tensor([[2.2908, 1.1990, 1.9171]])
Scores 0 - tensor([[2.7966, 2.2048, 2.4229]]) |-| tensor([3.6038])
Scores 1 - tensor([[2.8989, 1.9071, 3.1252]]) |-| tensor([3.8639])
Scores 2 - tensor([[2.4108, 1.4190, 2.6371]]) |-| tensor([3.3758])
tensor([[3.6038, 3.8639, 3.3758]])
tensor([[3.7038, 4.0639, 4.0758]])

tensor([5.0607])

"""

Let's try to break it down to understand what is going on:

# first index of `emissions` is the word position (0 = first word)
alpha0_0 = transitions[START_TAG, 0] + emissions[0, 0] # (0, 0)
alpha0_1 = transitions[START_TAG, 1] + emissions[0, 1] # (0, 1)
alpha0_2 = transitions[START_TAG, 2] + emissions[0, 2] # (0, 2)
print(alpha0_0, alpha0_1, alpha0_2)

# all combos of len 3 that finish with 0, so (0, 0, 0), (0, 1, 0), (0, 2, 0)
alpha1_0 = torch.logsumexp(torch.tensor([(eval(f"alpha0_{i}") + transitions[i, 0] + emissions[1, 0]) for i in range(dim_lbl)]).unsqueeze(0), 1)
# all combos of len 3 that finish with 1, so (0, 0, 1), (0, 1, 1), (0, 2, 1)
alpha1_1 = torch.logsumexp(torch.tensor([(eval(f"alpha0_{i}") + transitions[i, 1] + emissions[1, 1]) for i in range(dim_lbl)]).unsqueeze(0), 1)
alpha1_2 = torch.logsumexp(torch.tensor([(eval(f"alpha0_{i}") + transitions[i, 2] + emissions[1, 2]) for i in range(dim_lbl)]).unsqueeze(0), 1)
print(alpha1_0, alpha1_1, alpha1_2)

# all combos of len 4 that finish with 0, so (0, 0, 0, 0), (0, 0, 1, 0), (0, 0, 2, 0), .. , (0, 2, 1, 0) , (0, 2, 2, 0)
alpha2_0 = torch.logsumexp(torch.tensor([(eval(f"alpha1_{i}") + transitions[i, 0] + emissions[2, 0]) for i in range(dim_lbl)]).unsqueeze(0), 1)
alpha2_1 = torch.logsumexp(torch.tensor([(eval(f"alpha1_{i}") + transitions[i, 1] + emissions[2, 1]) for i in range(dim_lbl)]).unsqueeze(0), 1)
# all combos of len 4 that finish with 2, so (0, 0, 0, 2), (0, 0, 1, 2), (0, 0, 2, 2), ..,(0, 2, 1, 2) , (0, 2, 2, 2)
alpha2_2 = torch.logsumexp(torch.tensor([(eval(f"alpha1_{i}") + transitions[i, 2] + emissions[2, 2]) for i in range(dim_lbl)]).unsqueeze(0), 1)
print(alpha2_0, alpha2_1, alpha2_2)

alpha3_0 = torch.logsumexp(torch.tensor([(eval(f"alpha2_{i}") + transitions[i, STOP_TAG]) for i in range(dim_lbl)]).unsqueeze(0), 1)
print(alpha3_0)

"""
Outputs:

tensor(-0.1554) tensor(0.5215) tensor(0.2200)
tensor([2.2908]) tensor([1.1990]) tensor([1.9171])
tensor([3.6038]) tensor([3.8639]) tensor([3.3758])
tensor([5.0607])

"""

Each alpha is the log-sum-exp of the scores of all partial paths that end in a given tag. For example, alpha2_0 is:

string_computs = []
dim_seq_len, dim_lbl = emissions.shape
scr = 0
all_s = {}

# let's recompute all sequences scores first
for s1 in range(dim_lbl):
    for s2 in range(dim_lbl):
        for s3 in range(dim_lbl):
            s = (START_TAG ,) + (s1,)  + (s2,)  + (s3,) 
            string_computs.append(str(s))

            em_sum = 0
            for i in range(len(s)-1):
                em_sum += emissions[i, s[1:][i]].detach().numpy()

            scr_temp = 0
            for i, _ in enumerate(range(len(s)-1)): 
                scr_temp += transitions[s[i], s[i+1]].detach().numpy()

            scr += np.exp(em_sum + scr_temp)
            all_s[s] = em_sum + scr_temp

# now let's use `all_s` from above
cumsum = []
for e in all_s.keys():
    if e[-1] == 0:
        print(e, all_s[e])
        cumsum.append(all_s[e])

# sum all probabilities scores   
print(torch.logsumexp(torch.tensor(cumsum).unsqueeze(0), 1))

"""
Outputs:

(0, 0, 0, 0) 1.0562000572681427
(0, 0, 1, 0) 0.44230005890130997
(0, 0, 2, 0) 0.6604000255465508
(0, 1, 0, 0) 2.2331000342965126
(0, 1, 1, 0) 1.219200037419796
(0, 1, 2, 0) 1.4373000040650368
(0, 2, 0, 0) 1.431600034236908
(0, 2, 1, 0) 1.4177000224590302
(0, 2, 2, 0) 1.635799989104271

tensor([3.6038], dtype=torch.float64)

"""

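This is where the saving comes from. Say we are at alpha1_0, which already summarizes every partial path that ends in tag 0 at the second position: (0,0,0), (0,1,0) and (0,2,0). From here we have 3 choices for the next tag: 0, 1 or 2. Instead of recomputing the scores of those partial paths for each choice, we compute alpha1_0 once and then add the transition and emission scores for tag 0, 1 or 2 accordingly.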

We can illustrate the whole process as below:

Figure by author (figure 2)
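The recursion illustrated above can also be folded into one small function. The following is a plain-Python sketch (no torch, no batching) that reuses the toy scores from this section and checks the result against a brute-force enumeration; the function names `log_partition` and `log_partition_brute_force` are chosen here for illustration only.

```python
import math
from itertools import product


def logsumexp(values):
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def log_partition(emissions, transitions, start_tag, stop_tag):
    num_tags = len(transitions)
    # alpha[j]: log-sum-exp of all partial paths that end in tag j
    alpha = [transitions[start_tag][j] + emissions[0][j] for j in range(num_tags)]
    for emission in emissions[1:]:
        alpha = [
            logsumexp([alpha[i] + transitions[i][j] for i in range(num_tags)]) + emission[j]
            for j in range(num_tags)
        ]
    # close every path with the transition into the stop tag
    return logsumexp([alpha[i] + transitions[i][stop_tag] for i in range(num_tags)])


def log_partition_brute_force(emissions, transitions, start_tag, stop_tag):
    num_tags = len(transitions)
    scores = []
    for path in product(range(num_tags), repeat=len(emissions)):
        full = (start_tag,) + path + (stop_tag,)
        score = sum(emissions[i][t] for i, t in enumerate(path))
        score += sum(transitions[a][b] for a, b in zip(full, full[1:]))
        scores.append(score)
    return logsumexp(scores)


emissions = [[-0.4554, 0.4215, 0.12],
             [0.4058, -0.5081, 0.21],
             [0.2058, 0.5081, 0.02]]
transitions = [[0.3, 0.1, 0.1],
               [0.8, 0.2, 0.2],
               [0.3, 0.7, 0.7]]

log_z = log_partition(emissions, transitions, start_tag=0, stop_tag=1)
brute = log_partition_brute_force(emissions, transitions, start_tag=0, stop_tag=1)
print(round(log_z, 4), round(brute, 4))  # the torch version above prints 5.0607
```

Both functions return the same number, but only the brute-force one has to visit all 27 paths.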

Now that we have efficiently computed the partition function Z(x) we need to compute the second part of the loss given by:
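In the notation of the loss section, the score of the gold path is the sum of its emission and transition scores. The indexing below is an assumption made to match the code: n is the sequence length, y_0 is START_TAG and y_(n+1) is STOP_TAG.

```latex
\mathrm{score}(x, y) = \sum_{i=1}^{n} E(y_i \mid x) + \sum_{i=1}^{n+1} T(y_{i-1}, y_i)
```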

Our emission and transition scores are already in log space thus we simply do:

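# `tags` is assumed to hold the gold labels with shape (1, seq_len)
score = torch.zeros(1)
# prepend START_TAG so that tags[i] -> tags[i+1] is the transition into word i
tags = torch.cat([torch.tensor([START_TAG], dtype=torch.long), tags[0]])
for i, emission in enumerate(emissions):
    score = score + transitions[tags[i], tags[i+1]] + emission[tags[i+1]]
# finally add the transition from the last tag into STOP_TAG
score = score + transitions[tags[-1], STOP_TAG]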

By subtracting this gold path score from log Z(x) we obtain the negative log-likelihood loss, which is used to train the CRF and LSTM coefficients.
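To see the two parts of the loss working together, here is a small plain-Python sketch on the 3-tag toy scores from the forward algorithm section. The gold label sequence `gold_tags` is a hypothetical choice made only for this example, and log Z(x) is computed by brute force to keep the snippet short.

```python
import math
from itertools import product

emissions = [[-0.4554, 0.4215, 0.12],
             [0.4058, -0.5081, 0.21],
             [0.2058, 0.5081, 0.02]]
transitions = [[0.3, 0.1, 0.1],
               [0.8, 0.2, 0.2],
               [0.3, 0.7, 0.7]]
START_TAG, STOP_TAG = 0, 1


def path_score(tags):
    # emission scores along the path plus all transitions, START and STOP included
    full = [START_TAG] + list(tags) + [STOP_TAG]
    emit = sum(emissions[i][t] for i, t in enumerate(tags))
    trans = sum(transitions[a][b] for a, b in zip(full, full[1:]))
    return emit + trans


all_paths = list(product(range(3), repeat=len(emissions)))
log_z = math.log(sum(math.exp(path_score(p)) for p in all_paths))

gold_tags = (1, 0, 1)  # hypothetical gold labels, not from the article
nll = log_z - path_score(gold_tags)

# sanity check: the probabilities of all paths should sum to 1
total_prob = sum(math.exp(path_score(p) - log_z) for p in all_paths)
print(nll, total_prob)
```

Because log Z(x) is a log-sum-exp over every path, it is never smaller than the score of any single path, so the loss stays non-negative.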

Now that we understand how to efficiently calculate the loss and train the model, what about inference? At inference time we need to find the sequence with the highest probability, which poses a similar problem to the calculation of the partition function: a naive search takes exponential time because we need to loop through all possible sequences. The solution is another, closely related dynamic programming algorithm called the Viterbi algorithm.

Viterbi Algorithm (Inference)

As previously said, the Viterbi algorithm is needed to select the sequence with the highest probability (see the example in figure 1, CRF layer) given the emission and transition scores estimated from the data. Similarly to the forward algorithm, the Viterbi algorithm allows us to find it efficiently instead of computing the scores of all possible sequences and selecting the highest-scoring one at the end. Let's again start from some toy emission and transition scores:

transitions = torch.tensor([  [-1.0000e+04, -9.3,  0.6,  0.1,  0.6],
                              [-1.0000e+04, -1.0000e+04, -1.0000e+04, -1.0000e+04, -1.0000e+04],
                              [-1.0000e+04, -2,  -6,  0.9, 0.1],
                              [-1.0000e+04, 0.2,  -4, -5, 0.3],
                              [-1.0000e+04, 0.5, 0.3, -6.,  0.2]   ]).float()
transitions = torch.nn.Parameter(transitions)

emissions = torch.tensor([  [-0.2196,  -1.4047,  0.9992, 0.1948,  0.11],
                            [-0.2401,  0.4565,  0.3464, -0.1856,  0.2622],
                            [-0.3361,  0.1828, -0.3463, -0.2874,  0.2696]   ])

The inefficient solution looks like below where we loop through all possible sequences and then select the one with the highest score:

# first recompute scores for all sequences `all_s` 
dim_seq_len, dim_lbl = emissions.shape
scr = 0
all_s = {}
for s1 in range(dim_lbl):
    for s2 in range(dim_lbl):
        for s3 in range(dim_lbl):

            s = (START_TAG ,) + (s1,) + (s2,) + (s3,) + (STOP_TAG,)

            em_sum = 0
            for i in range(len(s)-2):
                em_sum += emissions[i, s[1:-1][i]].detach().numpy()

            scr_temp = 0
            for i, _ in enumerate(range(len(s)-1)):
                scr_temp += transitions[s[i], s[i+1]].detach().numpy()

            scr += np.exp(em_sum + scr_temp)
            all_s[s] = em_sum + scr_temp

sorted(all_s.items(), key=lambda x: x[1])[::-1]

"""
Most likely sequence is 
(0,2,3,4,1)  

[((0, 2, 3, 4, 1), 3.383200004696846),
 ((0, 2, 4, 4, 1), 2.931000016629696),
 ((0, 4, 2, 4, 1), 2.2260000333189964),
 ((0, 4, 2, 3, 1), 2.1689999997615814),
 ((0, 4, 4, 4, 1), 2.141800031065941),
 ((0, 3, 4, 4, 1), 1.8266000226140022),
 ((0, 2, 4, 2, 1), -0.08489998430013657),
 ((0, 4, 4, 2, 1), -0.8740999698638916),
 ((0, 3, 4, 2, 1), -1.1892999783158302),
 ((0, 3, 2, 4, 1), -2.489199995994568),
 ((0, 3, 2, 3, 1), -2.546200029551983),
 ...
 ((0, 1, 1, 1, 1), -30010.065400242805),
 ((0, 0, 0, 0, 1), -30010.095800206065)]

"""

The efficient way is to use the Viterbi algorithm, which is similar to the forward algorithm but with one key difference. When computing the partition function Z(x), the alphas accumulated the total score over all sequences, so each alpha was the log-sum-exp of the scores of all partial paths reaching that node. In Viterbi, the alphas are used to follow only the sequence with the highest score, so instead of summing at each alpha node we take the maximum and discard all the other paths. This is safe because the calculations after a given node are exactly the same whichever path led to it: a path that is not the maximum at this node can never end up with a higher final score than the maximum one.

Let's try to break it down. We are going to ignore the START_TAG (0) and STOP_TAG (1) nodes in the illustration below to make it clearer. We also know that, because of how we built the transition matrix, we will never have a solution that has START_TAG or STOP_TAG in intermediate nodes:

Figure by author (figure 3)

As we go from BOS to the next node (1st layer), we cannot discard any node at this stage because the sequence with maximum probability can pass by any node.

# we only compute these combos once
# even if in `all_s` above they appear multiple times.
# Specifically, each of them appears 25 times (5 * 5 continuations)
# first index of `emissions` is the word position (0 = first word)
alpha0_0 = transitions[START_TAG, 0] + emissions[0, 0] # (0, 0)
alpha0_1 = transitions[START_TAG, 1] + emissions[0, 1] # (0, 1)
alpha0_2 = transitions[START_TAG, 2] + emissions[0, 2] # (0, 2)
alpha0_3 = transitions[START_TAG, 3] + emissions[0, 3] # (0, 3)
alpha0_4 = transitions[START_TAG, 4] + emissions[0, 4] # (0, 4)

print(alpha0_0, alpha0_1, alpha0_2, alpha0_3,alpha0_4)
"""
Outputs:

tensor(-10000.2197, grad_fn=) 
tensor(-10.7047, grad_fn=) 
tensor(1.5992, grad_fn=) 
tensor(0.2948, grad_fn=) 
tensor(0.7100, grad_fn=)
"""

In the 2nd layer the situation is different: we see that we can reach node 2 from all the previous nodes 2, 3 and 4. Here, we only keep the path that leads to the maximum alpha, because we do the exact same calculations for any node after this one and we would end up with a lower final score if we did not select the maximum at this node.

alpha1_0_val, alpha1_0_idx = torch.max(torch.tensor([(eval(f"alpha0_{i}") + transitions[i, 0] + emissions[1, 0]) for i in range(5)]).unsqueeze(0), 1)
alpha1_1_val, alpha1_1_idx = torch.max(torch.tensor([(eval(f"alpha0_{i}") + transitions[i, 1] + emissions[1, 1]) for i in range(5)]).unsqueeze(0), 1)
alpha1_2_val, alpha1_2_idx = torch.max(torch.tensor([(eval(f"alpha0_{i}") + transitions[i, 2] + emissions[1, 2]) for i in range(5)]).unsqueeze(0), 1)
alpha1_3_val, alpha1_3_idx = torch.max(torch.tensor([(eval(f"alpha0_{i}") + transitions[i, 3] + emissions[1, 3]) for i in range(5)]).unsqueeze(0), 1)
alpha1_4_val, alpha1_4_idx = torch.max(torch.tensor([(eval(f"alpha0_{i}") + transitions[i, 4] + emissions[1, 4]) for i in range(5)]).unsqueeze(0), 1)

print("*"*100)
for tag in range(5):
    print(f"Scores to tag {tag}")
    print(torch.tensor([(eval(f"alpha0_{i}") + transitions[i, tag] + emissions[1, tag]) for i in range(5)]).unsqueeze(0))
    print("Max index alpha : ", eval(f"alpha1_{tag}_idx").item())
    print()

print("Maximum Values:")
print(alpha1_0_val, alpha1_1_val, alpha1_2_val, alpha1_3_val, alpha1_4_val)
print("Corresponding Indices:")
print(alpha1_0_idx, alpha1_1_idx, alpha1_2_idx, alpha1_3_idx, alpha1_4_idx)
print("*"*100)

"""
Outputs:

****************************************************************************************************
Scores to tag 0
tensor([[-20000.4590, -10010.9453,  -9998.6406,  -9999.9453,  -9999.5303]])
Max index alpha :  2

Scores to tag 1
tensor([[-1.0009e+04, -1.0010e+04,  5.5700e-02,  9.5130e-01,  1.6665e+00]])
Max index alpha :  4

Scores to tag 2
tensor([[-9.9993e+03, -1.0010e+04, -4.0544e+00, -3.3588e+00,  1.3564e+00]])
Max index alpha :  4

Scores to tag 3
tensor([[-1.0000e+04, -1.0011e+04,  2.3136e+00, -4.8908e+00, -5.4756e+00]])
Max index alpha :  2

Scores to tag 4
tensor([[-9.9994e+03, -1.0010e+04,  1.9614e+00,  8.5700e-01,  1.1722e+00]])
Max index alpha :  2

Maximum Values:
tensor([-9998.6406]) tensor([1.6665]) tensor([1.3564]) tensor([2.3136]) tensor([1.9614])
Corresponding Indices:
tensor([2]) tensor([4]) tensor([4]) tensor([2]) tensor([2])
****************************************************************************************************

"""

The calculations above tell us that if, for example, we want to transition to tag 3 at layer 2 (Scores to tag 3), the best tag to come from is tag 2 in the previous layer, as it leads to the maximum score. Similarly, if we want to transition to tag 2 at layer 2 (Scores to tag 2), the best tag to come from is tag 4, and so on. This is also illustrated in the graph: to reach tag 2 in layer 2 we need to choose tag 4 in layer 1, to reach tag 3 we choose tag 2, and to reach tag 4 we choose tag 2. This is what the red lines from layer 1 to layer 2 represent, together with the tuple between layers 1 and 2, (4,2,2).

We can verify this below:

layer = 2
tag = 3
selected = {}
for k, v in all_s.items():
    if k[layer] == tag:
        # print(k, v)
        selected[k] = v

# if we want to transition at layer 2 to tag 3,
# the best tag to do it is from tag 2       
sorted(selected.items(), key=lambda x: x[1])[::-1][0]
# ((0, 2, 3, 4, 1), 3.383200004696846) - (tags, path score)

layer = 2
tag = 2
selected = {}
for k, v in all_s.items():
    if k[layer] == tag:
        # print(k, v)
        selected[k] = v

# if we want to transition at layer 2 to tag 2,
# the best tag to do it is from tag 4       
sorted(selected.items(), key=lambda x: x[1])[::-1][0]
# ((0, 4, 2, 4, 1), 2.2260000333189964) - (tags, path score)

Then the 3rd layer:

Figure by author (figure 4)

alpha2_0_val, alpha2_0_idx = torch.max(torch.tensor([(eval(f"alpha1_{i}_val") + transitions[i, 0] + emissions[2, 0]) for i in range(5)]).unsqueeze(0), 1)
alpha2_1_val, alpha2_1_idx = torch.max(torch.tensor([(eval(f"alpha1_{i}_val") + transitions[i, 1] + emissions[2, 1]) for i in range(5)]).unsqueeze(0), 1)
alpha2_2_val, alpha2_2_idx = torch.max(torch.tensor([(eval(f"alpha1_{i}_val") + transitions[i, 2] + emissions[2, 2]) for i in range(5)]).unsqueeze(0), 1)
alpha2_3_val, alpha2_3_idx = torch.max(torch.tensor([(eval(f"alpha1_{i}_val") + transitions[i, 3] + emissions[2, 3]) for i in range(5)]).unsqueeze(0), 1)
alpha2_4_val, alpha2_4_idx = torch.max(torch.tensor([(eval(f"alpha1_{i}_val") + transitions[i, 4] + emissions[2, 4]) for i in range(5)]).unsqueeze(0), 1)

print("*"*100)
for tag in range(5):
    print(f"Scores to tag {tag}")
    print(torch.tensor([(eval(f"alpha1_{i}_val") + transitions[i, tag] + emissions[2, tag]) for i in range(5)]).unsqueeze(0))
    print("Max index alpha : ", eval(f"alpha2_{tag}_idx").item())
    print()

print("Maximum Values:")
print(alpha2_0_val, alpha2_1_val, alpha2_2_val, alpha2_3_val, alpha2_4_val)
print("Corresponding Indices:")
print(alpha2_0_idx, alpha2_1_idx, alpha2_2_idx, alpha2_3_idx, alpha2_4_idx)
print("*"*100)

"""
Outputs:

****************************************************************************************************
Scores to tag 0
tensor([[-19998.9766,  -9998.6699,  -9998.9795,  -9998.0225,  -9998.3750]])
Max index alpha :  3

Scores to tag 1
tensor([[-1.0008e+04, -9.9982e+03, -4.6080e-01,  2.6964e+00,  2.6442e+00]])
Max index alpha :  3

Scores to tag 2
tensor([[-9.9984e+03, -9.9987e+03, -4.9899e+00, -2.0327e+00,  1.9151e+00]])
Max index alpha :  4

Scores to tag 3
tensor([[-9.9988e+03, -9.9986e+03,  1.9690e+00, -2.9738e+00, -4.3260e+00]])
Max index alpha :  2

Scores to tag 4
tensor([[-9.9978e+03, -9.9981e+03,  1.7260e+00,  2.8832e+00,  2.4310e+00]])
Max index alpha :  3

Maximum Values:
tensor([-9998.0225]) tensor([2.6964]) tensor([1.9151]) tensor([1.9690]) tensor([2.8832])
Corresponding Indices:
tensor([3]) tensor([3]) tensor([4]) tensor([2]) tensor([3])
****************************************************************************************************
"""

The 4th and last layer is where our recurrence stops. Here we select the last node (before STOP_TAG) of the maximum sequence, from which we can then backtrack to find the entire maximum sequence, as I will show shortly:

alpha3_0_val, alpha3_0_idx = torch.max(torch.tensor([(eval(f"alpha2_{i}_val") + transitions[i, 1]) for i in range(5)]).unsqueeze(0), 1)
print(torch.tensor([(eval(f"alpha2_{i}_val") + transitions[i, 1]) for i in range(5)]).unsqueeze(0))
print(alpha3_0_val, alpha3_0_idx)
"""
tensor([[-1.0007e+04, -9.9973e+03, -8.4900e-02,  2.1690e+00,  3.3832e+00]])
tensor([3.3832]) tensor([4])
"""

Putting it all together:

dim_seq_len, dim_lbl = emissions.shape

backpointers = []
# Initialize the viterbi variables in log space
init_alphas = transitions[START_TAG ] + emissions[:1]

# alphas at step i holds the viterbi variables for step i-1
alphas = init_alphas
print("*" * 100)
print("Start Alphas : ", alphas)
print("*" * 100)
for l, emission in enumerate(emissions[1:], 1):
    bptrs_t = []  # holds the backpointers for this step
    viterbivars_t = []  # holds the viterbi variables for this step

    for next_tag in range(dim_lbl):
        # next_tag_var[0][i] holds the viterbi variable for tag i at the
        # previous step, plus the score of transitioning from tag i to
        # next_tag, plus the emission score of next_tag. The emission is
        # the same for every i, so it does not change which i is the max.
        next_tag_var = alphas + transitions[:, next_tag] + emission[next_tag]
        best_tag_score, best_tag_id = torch.max(next_tag_var, dim=-1)
        bptrs_t.append(best_tag_id)
        viterbivars_t.append(next_tag_var[0][best_tag_id].view(1))
        print(f"Scores for tag {next_tag} - {next_tag_var}\n Max index alpha : {best_tag_id.item()}, max value alpha : {best_tag_score.item()}")

    # Assign alphas to the set of viterbi variables we just computed
    # (the emission scores are already included above)
    alphas = (torch.cat(viterbivars_t)).view(1, -1)
    print("*" * 100)
    print(f"Alphas layer {l}: ", alphas)
    print("*" * 100)
    backpointers.append(bptrs_t)

# Transition to STOP_TAG
terminal_alphas = alphas + transitions[:, STOP_TAG]
# best tag at the end of the sequence (before the STOP_TAG)
best_tag_score, best_tag_id = torch.max(terminal_alphas, dim=-1)
print("*" * 100)
print("End Alphas : ", terminal_alphas)
print("*" * 100)

print(f"Max index alpha : {best_tag_id.item()}, max value alpha : {best_tag_score.item()}")

"""
Outputs:

****************************************************************************************************
Start Alphas :  tensor([[-1.0000e+04, -1.0705e+01,  1.5992e+00,  2.9480e-01,  7.1000e-01]],
       grad_fn=)
****************************************************************************************************
Scores for tag 0 - tensor([[-20000.4590, -10010.9453,  -9998.6406,  -9999.9453,  -9999.5303]],
       grad_fn=)
 Max index alpha : 2, max value alpha : -9998.640625
Scores for tag 1 - tensor([[-1.0009e+04, -1.0010e+04,  5.5700e-02,  9.5130e-01,  1.6665e+00]],
       grad_fn=)
 Max index alpha : 4, max value alpha : 1.6665000915527344
Scores for tag 2 - tensor([[-9.9993e+03, -1.0010e+04, -4.0544e+00, -3.3588e+00,  1.3564e+00]],
       grad_fn=)
 Max index alpha : 4, max value alpha : 1.3564000129699707
Scores for tag 3 - tensor([[-1.0000e+04, -1.0011e+04,  2.3136e+00, -4.8908e+00, -5.4756e+00]],
       grad_fn=)
 Max index alpha : 2, max value alpha : 2.3135998249053955
Scores for tag 4 - tensor([[-9.9994e+03, -1.0010e+04,  1.9614e+00,  8.5700e-01,  1.1722e+00]],
       grad_fn=)
 Max index alpha : 2, max value alpha : 1.961400032043457
****************************************************************************************************
Alphas layer 1:  tensor([[-9.9986e+03,  1.6665e+00,  1.3564e+00,  2.3136e+00,  1.9614e+00]],
       grad_fn=)
****************************************************************************************************
Scores for tag 0 - tensor([[-19998.9766,  -9998.6699,  -9998.9795,  -9998.0225,  -9998.3750]],
       grad_fn=)
 Max index alpha : 3, max value alpha : -9998.0224609375
Scores for tag 1 - tensor([[-1.0008e+04, -9.9982e+03, -4.6080e-01,  2.6964e+00,  2.6442e+00]],
       grad_fn=)
 Max index alpha : 3, max value alpha : 2.6963999271392822
Scores for tag 2 - tensor([[-9.9984e+03, -9.9987e+03, -4.9899e+00, -2.0327e+00,  1.9151e+00]],
       grad_fn=)
 Max index alpha : 4, max value alpha : 1.9150999784469604
Scores for tag 3 - tensor([[-9.9988e+03, -9.9986e+03,  1.9690e+00, -2.9738e+00, -4.3260e+00]],
       grad_fn=)
 Max index alpha : 2, max value alpha : 1.9690001010894775
Scores for tag 4 - tensor([[-9.9978e+03, -9.9981e+03,  1.7260e+00,  2.8832e+00,  2.4310e+00]],
       grad_fn=)
 Max index alpha : 3, max value alpha : 2.883199691772461
****************************************************************************************************
Alphas layer 2:  tensor([[-9.9980e+03,  2.6964e+00,  1.9151e+00,  1.9690e+00,  2.8832e+00]],
       grad_fn=)
****************************************************************************************************
****************************************************************************************************
End Alphas :  tensor([[-1.0007e+04, -9.9973e+03, -8.4900e-02,  2.1690e+00,  3.3832e+00]],
       grad_fn=)
****************************************************************************************************
Max index alpha : 4, max value alpha : 3.383199691772461

"""

Now that we have all the information, we can find the most likely tag sequence for our sentence by following the backpointers from the last layer back to the first:

Figure by author (figure 5)

# start from the best final tag (tag 4 in the illustration)
# and keep its score as the score of the whole path
path_score = best_tag_score
best_path = [best_tag_id]

# Follow the back pointers to decode the best path.
for bptrs_t in reversed(backpointers):
    # best tag to follow given best tag from next layer
    best_tag_id = bptrs_t[best_tag_id]
    best_path.append(best_tag_id)

# reverse best path list
best_path.reverse()
best_path = torch.cat(best_path)

path_score, best_path

"""
Outputs:

(tensor([3.3832], grad_fn=), tensor([2, 3, 4]))

"""

First we find the tag with the maximum alpha score in the 3rd layer, which is 4. Then we use the stored backpointers to find which tag we need to select in layer 2 given that the tag in the 3rd layer is 4: it's tag 3. We do the same for the previous layer and select tag 2. Thus, the sequence with the highest score according to the Viterbi algorithm is (0,2,3,4,1), the same one we found with the inefficient solution, but obtained much more quickly because we discarded all but one path at every intermediate alpha node.
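For reference, the forward pass and the backtracking can be combined into a single compact function. The following is a plain-Python sketch (no torch, no batching) run on the toy scores of this section; the function name `viterbi` and the shorthand `N` for the -10000 entries are choices made here for illustration.

```python
def viterbi(emissions, transitions, start_tag, stop_tag):
    num_tags = len(transitions)
    # best score of any partial path ending in each tag after the first word
    alpha = [transitions[start_tag][j] + emissions[0][j] for j in range(num_tags)]
    backpointers = []
    for emission in emissions[1:]:
        new_alpha, pointers = [], []
        for j in range(num_tags):
            candidates = [alpha[i] + transitions[i][j] for i in range(num_tags)]
            best_prev = max(range(num_tags), key=candidates.__getitem__)
            pointers.append(best_prev)
            # the emission is the same for every previous tag, so add it after the max
            new_alpha.append(candidates[best_prev] + emission[j])
        alpha = new_alpha
        backpointers.append(pointers)

    # transition into the stop tag, then pick the best final tag
    terminal = [alpha[i] + transitions[i][stop_tag] for i in range(num_tags)]
    best_last = max(range(num_tags), key=terminal.__getitem__)

    # follow the backpointers from the last word to the first
    path = [best_last]
    for pointers in reversed(backpointers):
        path.append(pointers[path[-1]])
    path.reverse()
    return terminal[best_last], path


N = -1e4
transitions = [[N, -9.3, 0.6, 0.1, 0.6],
               [N, N, N, N, N],
               [N, -2, -6, 0.9, 0.1],
               [N, 0.2, -4, -5, 0.3],
               [N, 0.5, 0.3, -6.0, 0.2]]
emissions = [[-0.2196, -1.4047, 0.9992, 0.1948, 0.11],
             [-0.2401, 0.4565, 0.3464, -0.1856, 0.2622],
             [-0.3361, 0.1828, -0.3463, -0.2874, 0.2696]]

score, path = viterbi(emissions, transitions, start_tag=0, stop_tag=1)
print(round(score, 4), path)  # 3.3832 [2, 3, 4]
```

The result matches the step-by-step computation above: the best path is tags 2, 3, 4 between START_TAG and STOP_TAG.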

Full LSTM-CRF code

import torch
import torch.nn as nn

IMPOSSIBLE = -1e4

class BiLSTM_CRF(nn.Module):

    def __init__(
        self, vocab_size, num_tags, start_tag, stop_tag,
        embedding_dim, hidden_dim
    ):
        super().__init__()

        self.num_tags = num_tags
        self.START_TAG = start_tag
        self.STOP_TAG = stop_tag

        # CRF parameters
        self.transitions = nn.Parameter(torch.randn(self.num_tags, self.num_tags))

        # These two statements enforce the constraint that we never transfer
        # to the start tag and we never transfer from the stop tag
        self.transitions.data[:, self.START_TAG] = IMPOSSIBLE
        self.transitions.data[self.STOP_TAG, :] = IMPOSSIBLE

        # LSTM parameters
        self.hidden_dim = hidden_dim
        self.word_embeds = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2,
                            num_layers=1, bidirectional=True,
                            batch_first=False)

        # Maps the output of the LSTM into tag space.
        self.hidden2tag = nn.Linear(hidden_dim, self.num_tags)

        self.hidden = self.init_hidden()

    def init_hidden(self):
        return (torch.randn(2, 1, self.hidden_dim // 2),
                torch.randn(2, 1, self.hidden_dim // 2))

    def _get_emissions(self, sentence):
        self.hidden = self.init_hidden()
        embeds = self.word_embeds(sentence).view(len(sentence), 1, -1)
        lstm_out, self.hidden = self.lstm(embeds, self.hidden)
        lstm_out = lstm_out.view(len(sentence), self.hidden_dim)
        emissions = self.hidden2tag(lstm_out)
        return emissions

    def neg_log_likelihood(self, sentence, tags):
        emissions = self._get_emissions(sentence)
        forward_score = self._forward_alg(emissions)
        gold_score = self._score_sentence(emissions, tags)
        return forward_score - gold_score

    def _forward_alg(self, emissions):

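        # Initialize the forward variables in log space. emissions[:1] keeps
        # the leading dimension, so alphas has shape (1, num_tags) even for
        # a one-word sentence (same initialization as in _viterbi_decode).
        init_alphas = self.transitions[self.START_TAG] + emissions[:1]

        alphas = init_alphas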

        for emission in emissions[1:]:
            alphas_t = []  # The forward tensors at this timestep
            for next_tag in range(self.num_tags):
                # emission score for the next tag
                emit_score = emission[next_tag].view(1, -1).expand(1, self.num_tags)
                # transition score from any previous tag to the next tag
                trans_score = self.transitions[:, next_tag].view(1, -1)
                # combine current scores with previous alphas 
                # since alphas are in log space (see logsumexp below),
                # we add them instead of multiplying
                next_tag_var = alphas + trans_score + emit_score

                alphas_t.append(torch.logsumexp(next_tag_var, 1).view(1))

            alphas = torch.cat(alphas_t).view(1, -1)

        terminal_alphas = alphas + self.transitions[:, self.STOP_TAG]
        alphas = torch.logsumexp(terminal_alphas, 1)

        return alphas

    def _viterbi_decode(self, emissions):
        backpointers = []

        # Initialize the viterbi variables in log space
        init_alphas = self.transitions[self.START_TAG] + emissions[:1]

        # alphas at step i holds the viterbi variables for step i-1
        alphas = init_alphas
        for emission in emissions[1:]:
            bptrs_t = []  # holds the backpointers for this step
            viterbivars_t = []  # holds the viterbi variables for this step

            for next_tag in range(self.num_tags):
                # next_tag_var[0][i] holds the viterbi variable for tag i at
                # the previous step, plus the score of transitioning from
                # tag i to next_tag, plus the emission score of next_tag.
                # The emission score is the same for every previous tag,
                # so it does not change which previous tag wins the max.
                next_tag_var = alphas + self.transitions[:, next_tag] + emission[next_tag]
                best_tag_score, best_tag_id = torch.max(next_tag_var, dim=-1)
                bptrs_t.append(best_tag_id)
                viterbivars_t.append(next_tag_var[0][best_tag_id].view(1))
            # Assign alphas to the set of viterbi variables we just computed
            alphas = (torch.cat(viterbivars_t)).view(1, -1)
            backpointers.append(bptrs_t)

        # Transition to STOP_TAG
        terminal_alphas = alphas + self.transitions[:, self.STOP_TAG]
        best_tag_score, best_tag_id = torch.max(terminal_alphas, dim=-1)
        path_score = terminal_alphas[0][best_tag_id]

        # Follow the back pointers to decode the best path.
        # Append terminal tag 
        best_path = [best_tag_id]
        for bptrs_t in reversed(backpointers):
            best_tag_id = bptrs_t[best_tag_id]
            best_path.append(best_tag_id)

        best_path.reverse()
        best_path = torch.cat(best_path)

        return path_score, best_path

    def forward(self, sentence): 
        # Get the emission scores from the BiLSTM
        emissions = self._get_emissions(sentence)
        print(emissions)

        # Find the best path, given the emission scores.
        score, tag_seq = self._viterbi_decode(emissions)
        return score, tag_seq

    def _score_sentence(self, emissions, tags):
        # Gives the score of a provided tag sequence
        score = torch.zeros(1)
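        # reshape(-1) accepts both a 1-D tag tensor and one with a leading
        # batch dimension of size 1
        tags = torch.cat([torch.tensor([self.START_TAG], dtype=torch.long), tags.reshape(-1)])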
        for i, emission in enumerate(emissions):
            score = score + self.transitions[tags[i], tags[i+1]] + emission[tags[i+1]]
        score = score + self.transitions[tags[-1], self.STOP_TAG]
        return score

def prepare_sequence(seq, to_ix):
    idxs = [to_ix[w] for w in seq]
    return torch.tensor(idxs, dtype=torch.long)

if __name__ == '__main__':

    START_TAG = "<START>"
    STOP_TAG = "<STOP>"
    EMBEDDING_DIM = 5
    HIDDEN_DIM = 4

    training_data = [
        (
            "Google Deepmind company".split(), 
            "B I O".split(),
        )
    ]

    word_to_ix = {START_TAG: 0, STOP_TAG: 1}
    for sentence, tags in training_data:
        for word in sentence:
            if word not in word_to_ix:
                word_to_ix[word] = len(word_to_ix)

    tag_to_ix = {START_TAG: 0, STOP_TAG: 1, 'B': 2, 'I': 3, 'O': 4}
    print(word_to_ix)

    crf_mod = BiLSTM_CRF(len(word_to_ix), len(tag_to_ix), tag_to_ix[START_TAG], tag_to_ix[STOP_TAG], 
                          embedding_dim=EMBEDDING_DIM, hidden_dim=HIDDEN_DIM)

    sentence, tags = training_data[0]
    sentence_in = prepare_sequence(sentence, word_to_ix)
    targets = torch.tensor([tag_to_ix[t] for t in tags], dtype=torch.long)

    torch.manual_seed(1)
    print(sentence_in, targets)

    score, tag_seq = crf_mod(sentence_in)
    print(score, tag_seq)
# For reproducibility of the example in the tutorial 
# set the matrices as follows:

crf_mod.word_embeds.weight.data = torch.tensor([[-1.5256, -0.7502, -0.6540, -1.6095, -0.1002],
        [-0.6092, -0.9798, -1.6091, -0.7121,  0.3037],
        [-0.7773, -0.2515, -0.2223,  1.6871,  0.2284],
        [ 0.4676, -0.6970, -1.1608,  0.6995,  0.1991],
        [ 0.8657,  0.2444, -0.6629,  0.8073, -1.8821],
        [-0.7765,  2.0242, -0.0865,  0.0981, -1.2150],
        [ 0.7312,  1.1718,  2.4070,  0.2786,  0.2468],
        [ 1.1843, -0.7282,  1.1633, -0.0091, -0.8425]])

crf_mod.hidden2tag.weight.data = torch.tensor([[ 0.1211, -0.3697,  0.4269, -0.1940],
    [ 0.3012,  0.0149, -0.0389, -0.0160],
    [ 0.0850,  0.2357,  0.0802,  0.1525],
    [-0.4498,  0.3643,  0.4359,  0.4133],
    [ 0.3696, -0.3608, -0.1854,  0.4409]])

crf_mod.hidden2tag.bias.data = torch.tensor([-0.3808,  0.4536, -0.3932, -0.3522,  0.2444])

crf_mod.lstm.weight_hh_l0.data = torch.tensor([[ 0.2411,  0.1267],
    [-0.3008, -0.2141],
    [ 0.6476, -0.1308],
    [ 0.3987,  0.3062],
    [-0.4571, -0.6013],
    [ 0.6787,  0.0369],
    [ 0.4847,  0.1465],
    [ 0.2274,  0.5282]])

crf_mod.lstm.weight_hh_l0_reverse.data = torch.tensor([[-0.5635, -0.3223],
    [-0.2166,  0.3024],
    [ 0.1292,  0.1747],
    [ 0.7058,  0.6892],
    [ 0.4823,  0.0225],
    [-0.4892,  0.5526],
    [-0.1768, -0.0572],
    [-0.6092, -0.1397]])

crf_mod.lstm.weight_ih_l0.data = torch.tensor([[-0.0331, -0.4720,  0.4306,  0.2195, -0.4571],
    [ 0.4593,  0.4293,  0.6271, -0.3964, -0.1164],
    [-0.0137,  0.1033, -0.5366, -0.5018,  0.3847],
    [-0.1658,  0.3454,  0.0403,  0.2322,  0.1555],
    [ 0.2571,  0.3505, -0.6549,  0.3559, -0.4972],
    [-0.5335,  0.0430, -0.1205,  0.4153, -0.4095],
    [-0.6286,  0.5146, -0.1049,  0.3977,  0.2273],
    [-0.5302,  0.1421,  0.1698, -0.4734, -0.3355]])

crf_mod.lstm.weight_ih_l0_reverse.data = torch.tensor([[ 0.4074,  0.6565, -0.4391,  0.1535,  0.6101],
    [ 0.4686,  0.4407,  0.5025,  0.4473,  0.1826],
    [-0.4835, -0.5938, -0.3240, -0.0823, -0.4334],
    [ 0.2587,  0.2188, -0.1601,  0.2718,  0.2285],
    [ 0.4317,  0.4762, -0.2395,  0.6909, -0.0817],
    [-0.0243, -0.6674, -0.4551, -0.4131, -0.3024],
    [ 0.5027, -0.2311, -0.5284,  0.2721,  0.2264],
    [ 0.4580, -0.3659,  0.1533, -0.2574, -0.1589]])

crf_mod.lstm.bias_hh_l0.data = torch.tensor([-0.4276,  0.0888,  0.7047, -0.4467,  0.3768, -0.3914, -0.6648, -0.1503])
crf_mod.lstm.bias_hh_l0_reverse.data = torch.tensor([-0.3321,  0.0400,  0.5119, -0.4974,  0.3321,  0.4543,  0.6917, -0.4949])
crf_mod.lstm.bias_ih_l0.data = torch.tensor([ 0.6705, -0.4692,  0.0884,  0.5277,  0.5123,  0.4393, -0.5117, -0.5092])
crf_mod.lstm.bias_ih_l0_reverse.data = torch.tensor([-0.4561,  0.6498, -0.6113, -0.5512, -0.0240, -0.3823,  0.2530, -0.2722])
crf_mod.transitions = torch.nn.Parameter(torch.tensor([[-1.0000e+04, -2.2910e-02,  7.2749e-02,  7.9209e-02,  9.4576e-02],
    [-1.0000e+04, -1.0000e+04, -1.0000e+04, -1.0000e+04, -1.0000e+04],
    [-1.0000e+04, -6.9932e-02,  6.5728e-02,  6.2673e-02, -7.9348e-02],
    [-1.0000e+04, -8.7596e-03,  4.2003e-02, -2.8962e-03, -5.0708e-02],
    [-1.0000e+04, -9.3993e-02, -7.0676e-02, -6.6569e-02,  8.2353e-02]]).float())

crf_mod.hidden = (torch.tensor([[ 0.6614,  0.2669],
                 [0.0617, 0.6213]]).unsqueeze(1), 
         torch.tensor([[-0.4519, -0.1661],
                 [-1.5228,  0.3817]]).unsqueeze(1))

CRFs Drawbacks and Conclusions

Although CRF-LSTM models have been widely used for sequence labeling tasks, they come with certain drawbacks compared to more recent Transformer models. One important drawback is that CRF-LSTM models are not good at modeling long-range dependencies between sequence elements and tend to work better with local context. Transformers, by contrast, use self-attention to capture long-range dependencies and excel at modeling global context. Another problem is that CRF-LSTM models process a sequence one element at a time, which limits parallelization and can be slow for long sequences, whereas Transformers process all positions in parallel and are therefore normally faster. However, one important advantage of the CRF-LSTM model is its interpretability: we can explore and make sense of the transition and emission matrices, while interpreting a Transformer model is more difficult.
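To illustrate what exploring the transition matrix can look like, the sketch below reads the matrix set in the reproducibility snippet of the full code and reports, for each tag, the highest-scoring next tag. Note that these are the fixed initial values used in this tutorial rather than trained weights, so the result only demonstrates the mechanics. The `most_likely_next` helper and the plain `"START"` / `"STOP"` labels are assumptions of this example.

```python
IMPOSSIBLE = -1e4

# Row and column order follows tag_to_ix in the full code
tag_names = ["START", "STOP", "B", "I", "O"]

# transitions[i][j] is the score of moving from tag i to tag j
transitions = [
    [-1.0000e+04, -2.2910e-02, 7.2749e-02, 7.9209e-02, 9.4576e-02],
    [-1.0000e+04, -1.0000e+04, -1.0000e+04, -1.0000e+04, -1.0000e+04],
    [-1.0000e+04, -6.9932e-02, 6.5728e-02, 6.2673e-02, -7.9348e-02],
    [-1.0000e+04, -8.7596e-03, 4.2003e-02, -2.8962e-03, -5.0708e-02],
    [-1.0000e+04, -9.3993e-02, -7.0676e-02, -6.6569e-02, 8.2353e-02],
]


def most_likely_next(transitions, tag_names):
    """Map each source tag to its highest-scoring allowed next tag."""
    result = {}
    for i, row in enumerate(transitions):
        # Skip transitions that were masked out with the IMPOSSIBLE score
        allowed = [
            (score, tag_names[j])
            for j, score in enumerate(row)
            if score > IMPOSSIBLE / 2
        ]
        if allowed:  # the STOP row has no allowed outgoing transition
            result[tag_names[i]] = max(allowed)[1]
    return result


preferred = most_likely_next(transitions, tag_names)
for source, target in preferred.items():
    print(f"{source} -> {target}")
```

After training, the same readout would show which label transitions the model has learned to favor or penalize, for example whether an I tag tends to follow a B tag.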

References

https://hyperscience.com/blog/exploring-conditional-random-fields-for-nlp-applications/
https://pytorch.org/tutorials/beginner/nlp/advanced_tutorial.html
https://people.cs.umass.edu/~mccallum/papers/crf-tutorial.pdf
https://createmomo.github.io/2017/11/11/CRF-Layer-on-the-Top-of-BiLSTM-5/
