Earn vs. Learn: Solving a Fishing Inspired Multi-Armed Bandit Problem

Image by author

I aim for you to get the following from this article:

  1. A good understanding of what constitutes a multi-armed bandit problem
  2. An understanding of how to solve a multi-armed bandit problem (what to consider, examples of common algorithms on simulated data with Python code)

A few years ago, I read a great book called ‘Algorithms to Live By', written by Brian Christian and Tom Griffiths. The main concept of the book is that we face many of the same types of decision problems in our daily lives that data professionals face in their work. No problem appears more often in my life than the multi-armed bandit problem. It shows up all over the place!

One place I see it is in my casual hobby of fishing. Quick disclaimer – I'm bad at fishing – but I still enjoy it. I often find myself wondering which lure I should use to catch as many fish as possible. While I'm using one lure, I'm always trying to decide if there is another one I should try that could catch more fish. This is a multi-armed bandit problem!

I thought it would be fun and informative to show various approaches to solving a fishing-inspired version of the multi-armed bandit problem. Below are the contents for the article:

  1. Defining the multi-armed bandit problem
  2. General strategies for solving the problem
  3. Testing various algorithms to solve the problem on a simulated fishing trip

What is the multi-armed bandit problem?

Apparently, an old nickname for slot machines is ‘one-armed bandit' (I'd never heard this before studying the multi-armed bandit problem). The ‘arm' refers to the lever or handle that you pull to play the machine.

Imagine you are at a casino with 3 slot machines (or ‘one-armed bandits') in front of you. Each has a different, unknown payout rate. How do you maximize your total payout for the trip to the casino? At its heart, the problem is an earn vs. learn challenge. Trying different machines gives you knowledge, but you buy that knowledge with the opportunity cost of not using a machine that you are already familiar with. Say you play one machine for 5 minutes and find that on average it pays out $5 every 15 pulls. Do you spend the rest of your afternoon at that machine? Or do you move to other machines that could have higher, or lower, payouts? How do you balance earning (using the machine you know) with learning (looking for something better)? Below is the general set up for the multi-armed bandit problem:

general set up for multi-armed bandit problem – by author

I've played slot machines once in my life. I lost $20 in a couple of minutes and decided the expense-to-fun ratio was way too high and quit – I get the ‘bandit' part of the name! For me, fishing is a much more relatable instance of the problem. There are a lot of fishing lures out there. I can only use one lure at a time and each lure has an unknown probability of catching a fish. The challenge is that I want to catch as many fish as possible on my fishing trip, but I don't know which lure is best. This problem requires me to ‘learn as I go' – I have to balance learning about the lures and catching as many fish as I can with lures I've become familiar with on the trip.

While this article is oriented towards the arbitrary example of selecting fishing lures, below is a list of other common places the multi-armed bandit problem shows up in personal and professional scenarios (just to name a few):

Common examples of the multi-armed bandit problem – by author

General strategies for solving the multi-armed bandit problem

While the objective of the multi-armed bandit problem is always the same – get as much good stuff (whatever it is; money, fish, efficiencies etc.) as possible – the strategy to meet that objective can vary widely depending on the circumstances.

There are a few key questions to answer when solving the problem:

  • How much time do you have?
  • How many things are there to explore? How big is the solution space?
  • Does it cost to switch from one solution to another? If so, what is the relative cost?
  • How speculative are your goals? Are you in an "all-or-nothing" scenario?

The answers to these questions can help you understand how much you should explore (learn) and how much you should exploit (earn). Below is a graphic depicting the relationships between the variables in the questions and the level of exploration:

image by author

We'll talk through each one of these:

Amount of time

If you have more time, you should explore more because you will have more time to exploit afterwards. I like the restaurant example Christian and Griffiths use. If you have just moved to a new city, you should eat at a lot of restaurants to find some good ones that you can return to multiple times throughout the years. If it is the last day before you move out of a city, finding a new favorite restaurant won't do you much good – you can't come back! So, go to your favorite and enjoy!

Solution Space Size

If you have a very large solution space (meaning a lot of slot machines) you generally should explore more, because there is a higher possibility that there is a better solution out there. There is also just more to explore, so it will require more exploration to get a proportional amount of learning.

Imagine that there is a 1% probability that any given slot machine has a very high average payout. If you only have 10 slot machines, there is a ~9.6% chance that at least one of the ten is a high payout machine (1 − 0.99¹⁰). But if you have 100 slot machines, there is a ~63% chance that there is a high payout machine among them (1 − 0.99¹⁰⁰). There is a higher probability that something really good is out there, so it makes more sense to look for it!
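
You can sanity check these numbers with a couple of lines of Python:

p_good = 0.01  # probability that any single machine is a high payout machine

# probability that at least one high payout machine exists among n machines
print(1 - (1 - p_good)**10)    # ~0.096, about a 9.6% chance with 10 machines
print(1 - (1 - p_good)**100)   # ~0.634, about a 63% chance with 100 machines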

Switching Cost

Often, trying new things costs more than staying with the current solution. One of the examples above is switching software platforms – changing platforms costs time and money: time to set the new software up and learn how to use it, plus license fees and other service charges. In the fishing example, it takes time to remove one lure and tie on another. If I change lures after every cast, I will not be able to make as many casts in a trip as I would if I used the same lure the whole time.

As the cost of switching increases, you should explore less. This is because the potential advantages of finding a better solution are offset by the guaranteed disadvantage of paying every time you search for a better solution.

Speculation/All or Nothing

Another important consideration is your overall strategy. If you are very speculative and are looking for an ‘all or nothing' type payout, you shouldn't do a lot of exploration. You will want to leave more up to chance. If you find a great solution early on, by chance, you won't waste opportunity looking for other solutions.

If you are in a position where you are competing with other people/entities that are trying to solve the same multi-armed bandit problem, and second place to you is equivalent to being last (i.e. winning is binary, you win or you lose) you don't want to spend a lot of time exploring. You want to roll the dice and hope that you get lucky.

This strategy is similar to calling a lot of upsets in a March Madness bracket. If you do this 1000 times, you will lose pretty badly most of the time – your average ranking will be lower than if you took a more conservative approach to your predictions. But the proportion of times that you finish in first place will likely be higher than if you took the conservative approach (assuming you are competing against a large number of people).


Solving the fishing inspired problem

Alright, enough hypothetical talk — let's simulate the fishing problem and start spinning up some algorithms to solve it!

The simulation is pretty simple: you have a tackle box with multiple lures in it, and each lure has a probability of catching a fish (that we set). We simulate a fishing trip by setting a fixed number of casts (e.g. in this fishing trip we have time for 300 casts). We then run the fishing trip with the tackle box through various algorithms to see which strategies give us the best results!

The fishing simulation I use here is pretty simple (I didn't include the code in the article, but you can find it in the repo here). If you are interested in other applications of simulating data, I wrote a whole series on it! Below is the link to the last article, which has links to the other articles in the series.

Simulated Data, Real Learnings : Simulating Systems

Here's our tackle box — the corresponding number is the probability of catching a fish on an individual cast with the specific lure:

lures = {'crank_bait' : 0.2,
         'spinner' : 0.01,
         'soft_plastic' : 0.15,
         'rooster' : 0.10,
         'jerk_bait' : 0.05}

We'll use this ‘tackle box' to simulate a fishing trip. We'll try different lure switching strategies to balance learning how good a lure is against using what we think is the best lure based on our experience up to that point.

Note: the algorithms (with the exception of the ‘optimal' algorithm) can't see the tackle box payout numbers; they are trying to learn these numbers while catching as many fish as possible.
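
The simulation class itself lives in the repo, but to make the snippets below easier to follow, here is a minimal sketch of what the two helpers they rely on (simulate_cast and random_start) might look like, assuming each cast is simply an independent Bernoulli draw. The real implementations in the repo are methods on the simulation class and may differ in their details:

import numpy as np

def simulate_cast(catch_prob, n=1):
    # sketch: each cast is an independent Bernoulli draw with the lure's
    # catch probability; returns a list of 0/1 results when n > 1 and a
    # single 0/1 value otherwise
    results = (np.random.uniform(size=n) < catch_prob).astype(int)
    return list(results) if n > 1 else int(results[0])

def random_start(lures, exclusions=[]):
    # sketch: pick a random lure (optionally excluding some) and return
    # its name along with its catch probability
    options = [lure for lure in lures if lure not in exclusions]
    choice = np.random.choice(options)
    return choice, lures[choice]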

Here's a quick outline of the algorithms that we are going to run the tackle box through:

run down of multi-armed bandit solving algorithms – by author

In this article, I'll share snippets of code that execute each algorithm to solve the problem. I won't share the full code base for brevity, but you can look at it/clone it in the linked repo here.

Optimal Solution

This is how many fish you would catch if you knew what the best lure was and you used it for the whole trip. This solution is not feasible in ‘real life', but serves as a ceiling to compare the other algorithms to.

The code is very simple: select the lure with the highest catch probability, simulate n casts with it, then record how many simulated fish you caught!

def optimal_strategy(self):

    '''
        Serves as the benchmark to compare all strategies to.
        It simply selects the lure with the highest catch probability
        (which wouldn't be known in practice) and only casts with that.

        inputs: None
        outputs: 
            cast_success (list) : list of 1's and 0's representing if a cast
                                  caught a fish or not

    '''

    lure_with_max_payout = max(self.lures, key=lambda key: self.lures[key])
    max_payout = self.lures[lure_with_max_payout]

    cast_success = self.simulate_cast(max_payout, n = self.casts)

    return cast_success

Random Solution

The optimal solution serves as a ceiling for how good an algorithm can get. The random solution algorithm serves as a floor. This strategy randomly selects lures for each cast and completely ignores how good/bad the fishing has been with each lure. Good strategies must beat the random strategy. If a strategy is not better than random, it is essentially worthless!

Here is the code for the random strategy.

def random(self):

    '''
        Runs a fishing trip where each cast uses a random
        lure.  Intended to be used as a floor for algorithm
        comparison.
    '''

    cast_success = []

    for cast in range(0, self.casts):
        random_lure, catch_prob = self.random_start()

        temp_cast_bool = self.simulate_cast(catch_prob)

        cast_success.append(temp_cast_bool)

    return cast_success

One-Round Learn

Now that we have a ceiling and a floor for our analysis, let's try a real candidate algorithm for solving the problem! The "one round learn" algorithm tries each lure for a user-specified number of casts. After it has gone through each lure, it picks the lure that caught the most fish and uses that lure for the remainder of the casts. In other words, it samples each lure n times and then uses the best lure (based on the sample results) for remaining casts on the trip.

To get the best results from this algorithm, we need to pick the optimal number of casts for each lure in the initial sampling phase. To do this, I ran the algorithm for multiple sample sizes. I picked the sample size with the highest average number of catches. In this specific example, the optimal sample size was 4.

Before I show the code, let's look at a graphical representation of how the algorithm works:

One-round learn algorithm graphic— image by author

Hopefully the graphic is clear. Here's the code to run the one-round learn algorithm:

def one_round_learn(self, num_tests, output_for_elminate_n = False,
                  output_for_elminate_dict = {},
                  output_for_greedy_eps = False):

  '''
      Limits learning to one round, where each lure is used for
      num_tests casts and then the lure with the highest yield
      is used for all remaining casts

      inputs:
          num_tests (int) : number of casts for each lure in the first round
          output_for_elminate_n (bool; False)  : indicates if output should be
                                                 set up for use in the eliminate
                                                 n algorithm
          output_for_elminate_dict (dict)      : dict of lures and their catch
                                                 probabilities to test (defaults
                                                 to the full tackle box)
          output_for_greedy_eps (bool; False)  : indicates if output that works
                                                 for the greedy epsilon algorithm
                                                 should be returned

      Outputs:
          catch_bool_list (list) : list of bools that represent the results
                                   of each cast in order.
  '''

  test_payouts = {}
  test_payouts_lists = {}
  cast_count = 0
  catch_bool_list = []

  # test each lure to decide which one to use for the
  # rest of the casts

  # use input dictionary if one is given
  if len(output_for_elminate_dict) == 0:

      lure_dict = self.lures

  else:
      lure_dict = copy(output_for_elminate_dict)

  # loop through lures
  for lure in lure_dict:

      lure_pct = lure_dict[lure]
      catch_list = []

      for test_cast in range(0, num_tests):

          cast_count += 1

          catch_bool = self.simulate_cast(lure_pct)
          catch_list.append(catch_bool)
          catch_bool_list.append(catch_bool)

      # now that test for this lure is done, count successes
      catches = np.sum(catch_list)
      test_payouts[lure] = catches
      test_payouts_lists[lure] = catch_list

  # return the results of the one round to feed into the greedy
  # epsilon algorithm if indicated by user
  if output_for_greedy_eps:
      return test_payouts_lists, catch_bool_list

  # get highest catching lure
  best_lure = max(test_payouts, key=test_payouts.get)
  best_lure_pct = self.lures[best_lure]

  if output_for_elminate_n:
      # return per-lure catch counts and the total catches from this round
      return test_payouts, np.sum(catch_bool_list)

  # now use remaining casts with the best lure from the tests
  for remaining_cast in range(0, self.casts - cast_count):

      catch_bool_best = self.simulate_cast(best_lure_pct)
      catch_bool_list.append(catch_bool_best)

  # return the full cast history: test casts plus remaining casts
  return catch_bool_list
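
As a side note, the sweep I described earlier for picking the number of test casts could look something like the snippet below. This is only a sketch: FishingTrip is a hypothetical stand-in for the simulation class in the linked repo, which holds the lures dictionary, the cast budget, and the methods shown in this article.

import numpy as np

# hypothetical sweep to tune num_tests for the one-round learn algorithm;
# FishingTrip stands in for the simulation class in the repo
avg_catches = {}

for num_tests in range(1, 11):
    catches = []
    for sim in range(500):
        trip = FishingTrip(lures=lures, casts=300)
        catch_list = trip.one_round_learn(num_tests)
        catches.append(np.sum(catch_list))
    avg_catches[num_tests] = np.mean(catches)

# pick the sample size with the highest average catch count
best_num_tests = max(avg_catches, key=avg_catches.get)
print(best_num_tests, avg_catches[best_num_tests])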

Iterative Elimination

The iterative elimination algorithm is essentially an extension of the one-round learn algorithm. In each round, every remaining lure is sampled for a specified number of casts, and the n lures with the fewest catches are removed from the ‘tackle box.' The process is then run again with the smaller tackle box, and more lures are eliminated. This continues until only one lure remains, and that lure is used exclusively for the remaining casts.

There are two inputs needed for this algorithm: (1) the number of lures eliminated in each round and (2) the number of casts for each lure in each round. I tested multiple combinations and found that the best combination was to eliminate 2 lures each round and to sample each lure with 3 casts each round.

Once again, before showing the code, let's take a look at a graphical representation of how the iterative elimination algorithm works. The graphic is a little busier since this is a more complicated algorithm than the one-round learn!

Iterative elimination algorithm graphic— image by author

And here is the code to run the algorithm:

def eliminate_n(self, n, num_tests):

      '''
          Iteratively samples each lure, eliminates the n worst
          performers, and repeats with the smaller tackle box until
          one lure remains. That lure is used for the remaining casts.

          inputs:
              n (int)         : number of lures to eliminate in each
                                iteration, higher is more aggressive as
                                more lures are eliminated more quickly
              num_tests (int) : number of casts per lure in each round

          outputs:
              catch_success (int) : total number of fish caught on the trip

      '''

      # sample every lure once to get the first round of catch counts
      lure_dict, catches = self.one_round_learn(num_tests,
                                                output_for_elminate_n = True)

      cast_count = num_tests*len(lure_dict)
      best_lure = max(lure_dict, key=lure_dict.get)

      while len(lure_dict) > 1 and cast_count <= self.casts:

          # sort the lures by observed catches in descending order
          sorted_dict = dict(sorted(lure_dict.items(), key=lambda item: item[1], reverse=True))

          # number of lures to carry into the next round
          n_keep = len(lure_dict) - n

          # if eliminating n more lures would leave fewer than n,
          # just take the best lure observed so far and use it
          # for the rest of the casts
          if n_keep < n:
              best_lure = list(sorted_dict.keys())[0]
              break
          else:
              # keep the top n_keep lures and sample them again
              keep_lures = list(sorted_dict.keys())[:n_keep]
              keep_dict = {lure: self.lures[lure] for lure in keep_lures}

              lure_dict, temp_catches = self.one_round_learn(num_tests,
                                                             output_for_elminate_n = True,
                                                             output_for_elminate_dict = keep_dict)

              catches += temp_catches
              cast_count += num_tests*len(keep_dict)
              best_lure = max(lure_dict, key=lure_dict.get)

      # now use remaining casts with the best lure from the tests
      best_lure_pct = self.lures[best_lure]

      for remaining_cast in range(0, self.casts - cast_count):

          catch_bool_best = self.simulate_cast(best_lure_pct)
          catches += catch_bool_best

      catch_success = catches

      return catch_success

Greedy Epsilon

This algorithm's primary strategy is to use the lure with the highest historical catch rate (hence the ‘greedy' in the name) and to randomly try other lures with a probability determined by epsilon, which is a user input. The higher the epsilon value, the higher the probability of moving away from the lure with the best catch rate. In other words, higher values of epsilon prioritize learning, lower values prioritize earning.

After trying multiple values for epsilon, I found that for this problem, the best epsilon rate was 0.10.

Let's look at the graphic of how the greedy epsilon algorithm works.

Greedy Epsilon Algorithm graphic – image by author

Here's the code for the greedy epsilon algorithm:

def epsilon_greedy(self, epsilon, rand_start_casts = 3):

          '''
              Starts by sampling every lure for a few casts. Then, with a
              probability of epsilon at each cast, switches to a random lure
              and gathers data. With a probability of 1-epsilon it uses the
              lure with the highest observed catch rate.

              inputs:
                  epsilon (float)        : probability of randomly switching
                                           lures on any given cast
                  rand_start_casts (int) : number of casts per lure in the
                                           initial sampling round - default = 3

              outputs:
                  cast_success (list) : list of 1's and 0's representing if
                                        a cast caught a fish or not

          '''

          # create a dictionary to keep track of observed catches
          obs_catch_dict, cast_success = self.one_round_learn(rand_start_casts, output_for_greedy_eps = True)

          cast_count = rand_start_casts*len(obs_catch_dict)

          # get the lure with the highest observed catch rate
          curr_lure = max(obs_catch_dict, key=lambda lure: np.mean(obs_catch_dict[lure]))
          curr_lure_pct = self.lures[curr_lure]

          # for the remaining casts, use the epsilon-greedy approach
          for _ in range(cast_count, self.casts):

              switch_prob = np.random.uniform()

              # random switch
              if epsilon >= switch_prob:

                  # switch lures randomly
                  curr_lure, curr_lure_pct = self.random_start(exclusions = [curr_lure])

                  # simulate cast with switched lure
                  temp_cast_bool = self.simulate_cast(curr_lure_pct)

                  cast_success.append(temp_cast_bool)

                  # add results from current cast to dictionary of catches
                  temp_catch_list = obs_catch_dict[curr_lure]
                  temp_catch_list.append(temp_cast_bool)
                  obs_catch_dict[curr_lure] = temp_catch_list

              # otherwise go back to the lure with the highest
              # observed catch rate
              else:

                  curr_lure = max(obs_catch_dict, key=lambda lure: np.mean(obs_catch_dict[lure]))
                  curr_lure_pct = self.lures[curr_lure]

                  # simulate cast with switched lure
                  temp_cast_bool = self.simulate_cast(curr_lure_pct)

                  # add results from current cast to dictionary of catches
                  temp_catch_list = obs_catch_dict[curr_lure]
                  temp_catch_list.append(temp_cast_bool)
                  obs_catch_dict[curr_lure] = temp_catch_list

                  cast_success.append(temp_cast_bool)

          return cast_success 

Upper Confidence Bound

This is a cool algorithm that attempts to set an upper limit on what the catch rate of a lure could be. This upper limit is based on the level of knowledge we have about the lure. If we only have a few casts with a lure, we have a small sample size, so its true catch rate could be much higher than the average we have observed so far (the law of large numbers hasn't kicked in yet). So, for lures with few casts, we give a higher ‘upper confidence bound', or ‘UCB' for short. On each cast, we fish with the lure that has the highest UCB. As we fish more with a lure, its UCB shrinks towards its observed average. If a lure's UCB drops below the UCB of another lure, we switch to that lure. This process continues until we run out of casts. Of course, as the trip goes on, the amount of switching decreases because the UCBs start to stabilize.

The UCB algorithm takes an input that can be used to decrease or increase the UCB. From some experimentation, I found that the value of 0.1 was the ideal multiplier for our fishing simulation.
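
To make the bonus concrete, here is a quick illustration of how the upper bound shrinks toward the observed catch rate as a lure accumulates casts. It mirrors the bonus term used in the code below, with the 0.1 multiplier and an assumed observed catch rate of 0.15:

bound_multiplier = 0.1
observed_rate = 0.15

# the bonus is bound_multiplier * (1 / number of casts), so it shrinks
# quickly as we gather more casts with a lure
for n_casts in [1, 5, 20, 100]:
    ucb = observed_rate + bound_multiplier * (1 / n_casts)
    print(n_casts, round(ucb, 3))   # 0.25, 0.17, 0.155, 0.151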

Of course, before getting into the code, we'll take a look at a visualization of the algorithm's process:

UCB algorithm graphic- image by author

And here is the code to execute the algorithm:

def ucb(self, bound_multiplier):

      '''
          Runs the upper confidence bound algorithm. 
          This algorithm adds a bonus to the observed 
          catch rate based on the number of observations.
          Higher observation count gives smaller bonus.

          inputs:
              bound_multiplier (float) : controls the size of the bonus
                                         larger -> larger bonus

          outputs:
              cast_success (list) : list of 1's and 0's representing if a cast
                                    caught a fish or not

      '''

      ucb_dict = {}
      # for each lure, create a dictionary that tracks cast
      # results and the current upper confidence bound
      for arm in self.lures:
          ucb_dict[arm] = {'success' : [],
                           'ucb'     : 1}

      cast_success = []

      # cast using UCB algorithm
      for cast in range(1, self.casts + 1):

          highest_ucb_lure = self.find_highest_ucb(ucb_dict)

          catch_bool = self.simulate_cast(self.lures[highest_ucb_lure])
          cast_success.append(catch_bool)

          # update ucb_dict for cast
          successes = ucb_dict[highest_ucb_lure]['success']
          successes.append(catch_bool)

          # observed catch rate plus a bonus that shrinks as the
          # number of casts with this lure grows
          prob_estimate = np.mean(successes)
          ucb_estimate = prob_estimate + bound_multiplier*(1/len(successes))

          ucb_dict[highest_ucb_lure] = {'success' : successes,
                                        'ucb' : ucb_estimate}

      return cast_success

def find_highest_ucb(self, ucb_dict):

      '''
          Called by ucb method. Finds the lure with the
          highest upper confidence bound.

          inputs:
              ucb_dict (dict) : dictionary with lure as keys and
                                upper confidence bound as element
          output:
              highest_lure (str) : name of lure with highest ucb
      '''

      highest_estimate = -np.inf

      for lure in ucb_dict:

          temp_estimate = ucb_dict[lure]['ucb']

          if temp_estimate > highest_estimate:
              highest_estimate = temp_estimate
              highest_lure = lure

      return highest_lure

Comparison of algorithm performances

Now that we have an understanding of the algorithms we will be testing, let's take a look at how they stack up against each other. I ran 500 simulations for each strategy to create a distribution of catches for each algorithm. The box plots for the algorithms are below:

image by author

There are two main things to consider when looking at this plot: (1) the average number of fish caught and (2) the variance of the catches. While the one-round learn and iterative elimination algorithms have the highest average catches, the spreads of their catches are huge! This is because we learn quickly, which often pays off, but sometimes we pick a lure too fast and fish with a really bad one for the rest of the trip. The greedy epsilon algorithm doesn't have a very impressive average catch count, but it does have a smaller variance – which in most cases is what we are looking for. I would say that the winner is the UCB algorithm. While it doesn't have the highest average, it has a good balance of a higher average with a much tighter range.
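
If you would like to reproduce a comparison like this, the harness is straightforward. Again, this is only a sketch: FishingTrip is a hypothetical stand-in for the simulation class in the repo, and np.sum is used so the loop works whether a strategy returns a cast-by-cast list or a single catch total:

import numpy as np
import matplotlib.pyplot as plt

strategies = {
    'optimal': lambda trip: trip.optimal_strategy(),
    'random': lambda trip: trip.random(),
    'one round learn': lambda trip: trip.one_round_learn(4),
    'iterative elimination': lambda trip: trip.eliminate_n(n=2, num_tests=3),
    'greedy epsilon': lambda trip: trip.epsilon_greedy(epsilon=0.10),
    'ucb': lambda trip: trip.ucb(bound_multiplier=0.1),
}

# run 500 simulated trips per strategy and collect total catches
catch_distributions = {}
for name, run_strategy in strategies.items():
    catch_distributions[name] = [np.sum(run_strategy(FishingTrip(lures=lures, casts=300)))
                                 for _ in range(500)]

plt.boxplot(list(catch_distributions.values()), labels=list(catch_distributions.keys()))
plt.ylabel('fish caught per trip')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()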

A quick note here: as I said, we don't typically want a wide range of variability – this leads to some great fishing trips and some terrible ones. In some cases, however, more variance is better, as I mentioned in the strategy section of this article. Imagine I was in a fishing tournament where the first place winner won a brand new Toyota Tundra and the second place winner got a coupon for a free ice cream cone. My best strategy would likely be to choose one-round learn or iterative elimination. This would give me the best chance of getting lucky and beating out the competition. With a more conservative strategy my placement in the tournament would be more predictable, but I don't really care about that – I want the truck!

Conclusion

The various algorithms we looked at all have different ways of handling the earn vs. learn challenge for the multi-armed bandit problem. The final results are not generalizable to all multi-armed bandit problems. Different payout rates (different solution spaces) will likely have different ideal algorithms and different parameter levels.

We also looked at a very simple version of the multi-armed bandit problem. We only had five arms with fairly different payouts. The problems get more difficult when there are more arms and/or the payouts are pretty similar (it is more difficult to differentiate the payouts). We also ignored the very real possibility that payouts could change during the problem solving, and the likely real world issue of switching costs.

Having said that, I hope you come away with a strong understanding of what a multi-armed bandit problem is and some common approaches to solving it.

