Applying $k$-nearest neighbor classifier to the Pokemon data set

by EC Corro | J Monje | JN Obrero | M Romero | A White

In [2]:
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split

%matplotlib inline

Pokemon Dataset

Image source: https://thesoundtrackonline.com/2018/03/10/ranked-all-seven-pokemon-generations/

1. Problems:

We want to explore whether the anime and game versions of Pokemon are consistent with each other. Specifically, we use the k-NN classification algorithm to explore the following questions:

  1. Can we classify all Pokemon according to their primary type (Type 1) using their basic stats alone?
  2. Which primary types of Pokemon can be distinguished easily in terms of their basic stats?
  3. Can we classify legendary and non-legendary Pokemon using their basic stats alone?

2. Highlights:

  1. It is difficult to classify ALL Pokemon by type based on their stats alone.

  2. Some pairs of Pokemon types can be classified accurately based on their stats.

  3. Legendaries and non-legendaries can be classified based on their stats alone.

3. Data Discussion:

The Pokemon data set used in this study was taken from Kaggle. Its attributes were sourced from various websites: pokemon.com, pokemondb and Bulbapedia. Note that the basic stats (HP, Attack, Defense, etc.) are based on the Pokemon games, not the show.

The data set has 800 data points, with the following columns:

  • #: Pokedex ID number of each Pokemon (alternate forms, such as Mega evolutions, share the number of their base form)
  • Name: name of each Pokemon
  • Type 1: primary type of each Pokemon, which determines its strengths and weaknesses
  • Type 2: secondary type; only some Pokemon have two types (NaN otherwise)
  • Total: sum of the basic stats
  • HP: health of a Pokemon, which determines how much damage it can withstand
  • Attack: the base modifier for physical attacks (e.g. scratch, punch)
  • Defense: the base damage resistance against physical attacks
  • Sp. Atk: special attack, the base modifier for special attacks (e.g. fire blast, bubble beam)
  • Sp. Def: special defense, the base damage resistance against special attacks
  • Speed: determines which Pokemon attacks first each round
  • Generation: the game generation in which the Pokemon was introduced
  • Legendary: whether the Pokemon is legendary (extremely rare, powerful and mythical)

Throughout this notebook, the term basic stats collectively refers to the features HP, Attack, Defense, Sp. Atk, Sp. Def and Speed.

To compare how consistent these game stats are with the anime version, we referred to this website.

In [3]:
df = pd.read_csv('pokemon_dataset.csv')
df.head()
Out[3]:
   #                   Name  Type 1  Type 2  Total  HP  Attack  Defense  Sp. Atk  Sp. Def  Speed  Generation  Legendary
0  1              Bulbasaur   Grass  Poison    318  45      49       49       65       65     45           1      False
1  2                Ivysaur   Grass  Poison    405  60      62       63       80       80     60           1      False
2  3               Venusaur   Grass  Poison    525  80      82       83      100      100     80           1      False
3  3  VenusaurMega Venusaur   Grass  Poison    625  80     100      123      122      120     80           1      False
4  4             Charmander    Fire     NaN    309  39      52       43       60       50     65           1      False
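
As a quick consistency check of the column definitions, the sketch below re-types the five rows shown by df.head() (only Total and the six basic stats) and confirms that Total is the row-wise sum of the basic stats. The list name BASIC_STATS is a hypothetical helper introduced for illustration; the same expression can be applied to the full df.

```python
import pandas as pd

# Hypothetical helper: the six basic stats as named in the CSV header
BASIC_STATS = ['HP', 'Attack', 'Defense', 'Sp. Atk', 'Sp. Def', 'Speed']

# The five rows shown by df.head(), re-typed by hand (Total + basic stats only)
head = pd.DataFrame(
    [[318, 45, 49, 49, 65, 65, 45],
     [405, 60, 62, 63, 80, 80, 60],
     [525, 80, 82, 83, 100, 100, 80],
     [625, 80, 100, 123, 122, 120, 80],
     [309, 39, 52, 43, 60, 50, 65]],
    columns=['Total'] + BASIC_STATS,
    index=['Bulbasaur', 'Ivysaur', 'Venusaur',
           'VenusaurMega Venusaur', 'Charmander'],
)

# True for each row whose Total equals the sum of its six basic stats
consistent = head[BASIC_STATS].sum(axis=1).eq(head['Total'])
print(consistent.all())

# On the full data set: (df[BASIC_STATS].sum(axis=1) == df['Total']).all()
```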

In Sections 4.1 and 4.2, the target column is the primary type (Type 1) of the Pokemon: we attempt to classify each Pokemon's primary type using only its basic stats. (Section 4.3 uses the Legendary column as the target instead.) The primary types and their corresponding counts are listed below.

In [4]:
df['Type 1'].value_counts().rename_axis('Type 1').reset_index(name='Count')
Out[4]:
      Type 1  Count
0      Water    112
1     Normal     98
2      Grass     70
3        Bug     69
4    Psychic     57
5       Fire     52
6   Electric     44
7       Rock     44
8     Ground     32
9     Dragon     32
10     Ghost     32
11      Dark     31
12    Poison     28
13  Fighting     27
14     Steel     27
15       Ice     24
16     Fairy     17
17    Flying      4

4. $k$-Nearest Neighbor Classification Implementation:

4.1 Considering ALL types in the classification:

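First, we implement the k-NN classification algorithm to classify all Pokemon according to their primary types using their basic stats. For each of 50 random 75/25 train/test splits, we fit a classifier for every $k$ (n_neighbors) from 1 to 39 and record its training and test accuracy; the plot shows the mean accuracy across splits, with one standard deviation as error bars.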
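
To make the algorithm concrete, the sketch below is a from-scratch version of the classification rule that KNeighborsClassifier applies with its default settings (Euclidean distance, each of the $k$ nearest training points casting one equal vote). It is an illustration only: scikit-learn uses optimized neighbor searches and may break voting ties differently. The function name knn_predict and the toy data are hypothetical.

```python
import numpy as np
from collections import Counter


def knn_predict(X_train, y_train, X_query, k):
    """Predict a label for each query row by majority vote among its k
    nearest training rows (Euclidean distance, uniform weights)."""
    X_train = np.asarray(X_train, dtype=float)
    X_query = np.asarray(X_query, dtype=float)
    y_train = np.asarray(y_train)
    predictions = []
    for x in X_query:
        # Euclidean distance from this query point to every training point
        distances = np.sqrt(((X_train - x) ** 2).sum(axis=1))
        # Indices of the k closest training points, nearest first
        nearest = np.argsort(distances)[:k]
        # Majority vote; most_common breaks ties by first occurrence,
        # i.e. in favor of the label seen among the closer neighbors
        votes = Counter(y_train[nearest])
        predictions.append(votes.most_common(1)[0][0])
    return np.array(predictions)


# Hypothetical toy data: (HP, Attack) for three weak and three strong Pokemon
X_train = [[40, 50], [45, 55], [50, 45], [100, 120], [110, 115], [105, 125]]
y_train = ['weak', 'weak', 'weak', 'strong', 'strong', 'strong']
X_query = [[48, 52], [102, 118]]
print(knn_predict(X_train, y_train, X_query, k=3))  # -> ['weak' 'strong']
```

Because the vote counts raw distances, a stat with a larger numeric range (such as HP) has more influence on which neighbors are "nearest" than a stat with a smaller range.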

In [5]:
df_features = df[['HP', 'Attack', 'Defense', 'Sp. Atk', 'Sp. Def', 'Speed']]
df_target = df['Type 1']
In [6]:
training_accuracy = []
test_accuracy = []
trials = range(50)
for trial in trials:
    (X_train, X_test, 
     y_train, y_test) = train_test_split(df_features, 
                                         df_target,
                                         test_size=0.25,
                                         random_state=trial)

    # Set n for n_neighbors from 1 to 39
    neighbors_settings = range(1, 40)

    for n_neighbors in neighbors_settings:
        # Build model
        clf = KNeighborsClassifier(n_neighbors=n_neighbors)
        clf.fit(X_train, y_train)
        # Record training accuracy for this trial
        training_accuracy.append(clf.score(X_train, y_train))
        # Record test (generalization) accuracy for this trial
        test_accuracy.append(clf.score(X_test, y_test))

#Reshaping accuracies in such a way that one row is one trial
training_accuracy = (np.array(training_accuracy)
                       .reshape(len(trials), len(neighbors_settings)))
#Calculate mean and standard deviation per column
training_err = np.std(training_accuracy, axis=0)
training_accuracy = np.mean(training_accuracy, axis=0)

#Reshape accuracies in such a way that one row is one trial
test_accuracy = (np.array(test_accuracy)
                   .reshape(len(trials), len(neighbors_settings)))
#Calculate mean and standard deviation per column
test_err = np.std(test_accuracy, axis=0)
test_accuracy = np.mean(test_accuracy, axis=0)

# Graph results
plt.figure(figsize=(10,5))
plt.errorbar(neighbors_settings, training_accuracy, yerr=training_err, 
             label="training accuracy")
plt.errorbar(neighbors_settings, test_accuracy, yerr=test_err, 
             label="test accuracy")
plt.title("Classifying Pokemon (ALL types)", fontsize=18)
plt.ylabel("Accuracy")
plt.xlabel("n_neighbors")
plt.xticks(range(1,len(neighbors_settings) + 1, 5))
plt.legend();
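
The manual loop above repeats train_test_split for each trial. As an alternative sketch, scikit-learn's cross_val_score combined with a ShuffleSplit splitter does the same bookkeeping of repeated random 75/25 splits. The data here is a synthetic stand-in from make_classification (an assumption, not the Pokemon data); to use the real data, pass df_features and df_target instead. ShuffleSplit seeds its splits differently from train_test_split(random_state=trial), so individual splits will not match the loop above.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import ShuffleSplit, cross_val_score
from sklearn.neighbors import KNeighborsClassifier

# Synthetic stand-in for the basic stats: 800 rows, 6 features, 3 classes
X, y = make_classification(n_samples=800, n_features=6, n_informative=4,
                           n_redundant=0, n_classes=3, random_state=0)

# 50 random 75/25 train/test splits, like the 50 trials above
cv = ShuffleSplit(n_splits=50, test_size=0.25, random_state=0)

neighbors_settings = range(1, 40)
mean_scores, std_scores = [], []
for n_neighbors in neighbors_settings:
    clf = KNeighborsClassifier(n_neighbors=n_neighbors)
    # One test-set accuracy per split (accuracy is the classifier default)
    scores = cross_val_score(clf, X, y, cv=cv)
    mean_scores.append(scores.mean())
    std_scores.append(scores.std())

best_k = neighbors_settings[int(np.argmax(mean_scores))]
print(f"best k = {best_k}, mean test accuracy = {max(mean_scores):.3f}")
```

cross_val_score returns only test scores; to also get the training accuracy plotted above, cross_validate(..., return_train_score=True) returns both.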

4.2 Considering TWO types in the classification:

In the Pokemon story, the effectiveness of a Pokemon's attacks and defenses depends on the type of its opponent. Hence, in this section, we pair up some Pokemon types and check whether each pair can be accurately classified using the k-NN classification algorithm.

In [7]:
pairs = [['Electric', 'Ground'],['Fighting', 'Steel'],['Psychic', 'Fighting']]
fig, axes = plt.subplots(1,3, figsize=(20,4))
for pair, ax in zip(pairs, axes):
    df_rows = df.loc[df['Type 1'].isin(pair)]
    df_features = df_rows[['HP', 'Attack', 'Defense', 'Sp. Atk', 'Sp. Def', 
                           'Speed']]
    df_target = df_rows['Type 1']
    
    training_accuracy = []
    test_accuracy = []
    trials = range(10)
    for trial in trials:
        (X_train, X_test, 
         y_train, y_test) = train_test_split(df_features, 
                                             df_target,
                                             test_size=0.25,
                                             random_state=trial)

        # Set n for n_neighbors from 1 to 39
        neighbors_settings = range(1, 40)

        for n_neighbors in neighbors_settings:
            # Build model
            clf = KNeighborsClassifier(n_neighbors=n_neighbors)
            clf.fit(X_train, y_train)
            # Record training accuracy for this trial
            training_accuracy.append(clf.score(X_train, y_train))
            # Record test (generalization) accuracy for this trial
            test_accuracy.append(clf.score(X_test, y_test))

    #Reshaping accuracies in such a way that one row is one trial
    training_accuracy = (np.array(training_accuracy)
                           .reshape(len(trials), len(neighbors_settings)))
    #Calculate mean and standard deviation per column
    training_err = np.std(training_accuracy, axis=0)
    training_accuracy = np.mean(training_accuracy, axis=0)

    #Reshape accuracies in such a way that one row is one trial
    test_accuracy = (np.array(test_accuracy)
                       .reshape(len(trials), len(neighbors_settings)))
    #Calculate mean and standard deviation per column
    test_err = np.std(test_accuracy, axis=0)
    test_accuracy = np.mean(test_accuracy, axis=0)

    # Graph results
    ax.errorbar(neighbors_settings, training_accuracy, yerr=training_err,
                label="training accuracy")
    ax.errorbar(neighbors_settings, test_accuracy, yerr=test_err,
                label="test accuracy")
    ax.set_ylabel("Accuracy")
    ax.set_xlabel("n_neighbors")
    # Tick every 10th value of k
    ax.set_xticks(range(1, len(neighbors_settings) + 1, 10))
    ax.set_title("{} vs {}".format(pair[0], pair[1]), fontsize=18)
    ax.legend()

4.3 Considering legendary status in the classification:

Maybe you once asked the question, "Are legendary Pokemon really way stronger than non-legendary ones?" This question pushed us to check whether we can accurately distinguish legendary Pokemon just by looking at their basic stats.

In [8]:
df.groupby('Legendary')['Generation'].count()
Out[8]:
Legendary
False    735
True      65
Name: Generation, dtype: int64
In [9]:
fig, ax = plt.subplots(1,2, figsize=(20,5.8))

# -------------- LEGENDARY VS NON-LEGENDARY (UNEQUAL SIZES)---------------------
df_features = df[['HP', 'Attack', 'Defense', 'Sp. Atk', 'Sp. Def', 
                       'Speed']]
df_target = df['Legendary']

training_accuracy = []
test_accuracy = []
trials = range(10)
for trial in trials:
    (X_train, X_test, 
     y_train, y_test) = train_test_split(df_features, 
                                         df_target,
                                         test_size=0.25,
                                         random_state=trial)

    # Set n for n_neighbors from 1 to 39
    neighbors_settings = range(1, 40)

    for n_neighbors in neighbors_settings:
        # Build model
        clf = KNeighborsClassifier(n_neighbors=n_neighbors)
        clf.fit(X_train, y_train)
        # Record training accuracy for this trial
        training_accuracy.append(clf.score(X_train, y_train))
        # Record test (generalization) accuracy for this trial
        test_accuracy.append(clf.score(X_test, y_test))

#Reshaping accuracies in such a way that one row is one trial
training_accuracy = (np.array(training_accuracy)
                       .reshape(len(trials), len(neighbors_settings)))
#Calculate mean and standard deviation per column
training_err = np.std(training_accuracy, axis=0)
training_accuracy = np.mean(training_accuracy, axis=0)

#Reshape accuracies in such a way that one row is one trial
test_accuracy = (np.array(test_accuracy)
                   .reshape(len(trials), len(neighbors_settings)))
#Calculate mean and standard deviation per column
test_err = np.std(test_accuracy, axis=0)
test_accuracy = np.mean(test_accuracy, axis=0)

# Graph results
ax[0].errorbar(neighbors_settings, training_accuracy, yerr=training_err,
               label="training accuracy")
ax[0].errorbar(neighbors_settings, test_accuracy, yerr=test_err,
               label="test accuracy")
ax[0].set_ylabel("Accuracy")
ax[0].set_xlabel("n_neighbors")
ax[0].set_xticks(range(1, len(neighbors_settings) + 1, 10))
ax[0].set_title("Legendary vs Non-legendary (Unequal Sizes)", fontsize=18)
ax[0].legend()

# -------------- LEGENDARY VS NON-LEGENDARY (EQUAL SIZES)---------------------

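# Randomly pick 65 non-legendary Pokemon (seeded so the sample is reproducible)
# and combine them with all 65 legendary Pokemon
df_rows = pd.concat([df[df['Legendary'] == False].sample(n=65, random_state=0),
                     df[df['Legendary'] == True]])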

df_features = df_rows[['HP', 'Attack', 'Defense', 'Sp. Atk', 'Sp. Def', 
                       'Speed']]
df_target = df_rows['Legendary']

training_accuracy = []
test_accuracy = []
trials = range(10)
for trial in trials:
    (X_train, X_test, 
     y_train, y_test) = train_test_split(df_features, 
                                         df_target,
                                         test_size=0.25,
                                         random_state=trial)

    # Set n for n_neighbors from 1 to 39
    neighbors_settings = range(1, 40)

    for n_neighbors in neighbors_settings:
        # Build model
        clf = KNeighborsClassifier(n_neighbors=n_neighbors)
        clf.fit(X_train, y_train)
        # Record training accuracy for this trial
        training_accuracy.append(clf.score(X_train, y_train))
        # Record test (generalization) accuracy for this trial
        test_accuracy.append(clf.score(X_test, y_test))

#Reshaping accuracies in such a way that one row is one trial
training_accuracy = (np.array(training_accuracy)
                       .reshape(len(trials), len(neighbors_settings)))
#Calculate mean and standard deviation per column
training_err = np.std(training_accuracy, axis=0)
training_accuracy = np.mean(training_accuracy, axis=0)

#Reshape accuracies in such a way that one row is one trial
test_accuracy = (np.array(test_accuracy)
                   .reshape(len(trials), len(neighbors_settings)))
#Calculate mean and standard deviation per column
test_err = np.std(test_accuracy, axis=0)
test_accuracy = np.mean(test_accuracy, axis=0)

# Graph results
ax[1].errorbar(neighbors_settings, training_accuracy, yerr=training_err,
               label="training accuracy")
ax[1].errorbar(neighbors_settings, test_accuracy, yerr=test_err,
               label="test accuracy")
ax[1].set_ylabel("Accuracy")
ax[1].set_xlabel("n_neighbors")
ax[1].set_xticks(range(1, len(neighbors_settings) + 1, 10))
ax[1].set_title("Legendary vs Non-legendary (Equal Sizes)", fontsize=18)
# Trailing semicolon suppresses the Legend object's repr in the notebook
ax[1].legend();

5. Discussion of Results/Insights:

5.1 It is difficult to classify ALL Pokemon by type based on their stats alone.

The k-NN classifier in 4.1, which attempts to classify all Pokemon by primary type, reaches a test accuracy of only around 20-25%. That is above the 14% obtained by always predicting the most common type (Water, 112 of 800 Pokemon), but far too low to be useful, which suggests that the primary types cannot all be told apart using basic stats alone. One possible reason is the dispersion of basic stats within each type: a Pokemon's stats may be no closer to those of its own type than to those of other types. For example, the same type contains both strong and weak Pokemon, such as Charizard and Charmander (both primarily Fire type).

In addition, there are 18 primary types, and with this many overlapping classes, one set of basic stats cannot be distinguished from the others when all types are compared at once. Some pairs of types, though, are easily differentiated, while other pairs yield low classification accuracy.
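
A classifier's accuracy is easier to judge next to simple reference points. The sketch below computes two of them from the primary-type counts printed in Out[4]: always predicting the most common type, and guessing uniformly at random among the 18 types. Both are computed on the whole data set, so the baseline on any single random test split will vary slightly.

```python
# Primary-type counts as printed by df['Type 1'].value_counts() in Out[4]
type_counts = {
    'Water': 112, 'Normal': 98, 'Grass': 70, 'Bug': 69, 'Psychic': 57,
    'Fire': 52, 'Electric': 44, 'Rock': 44, 'Ground': 32, 'Dragon': 32,
    'Ghost': 32, 'Dark': 31, 'Poison': 28, 'Fighting': 27, 'Steel': 27,
    'Ice': 24, 'Fairy': 17, 'Flying': 4,
}

n_total = sum(type_counts.values())      # number of Pokemon
n_classes = len(type_counts)             # number of primary types
majority_type = max(type_counts, key=type_counts.get)

# Accuracy of always predicting the most common type
majority_baseline = type_counts[majority_type] / n_total
# Expected accuracy of guessing a type uniformly at random
uniform_baseline = 1 / n_classes

print(f"{n_classes} classes, {n_total} Pokemon")
print(f"majority-class baseline ({majority_type}): {majority_baseline:.1%}")
print(f"uniform random-guess baseline: {uniform_baseline:.1%}")
```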

5.2 Some pairs of Pokemon types can be classified accurately based on their stats.

Even though the model has trouble classifying all types at once from basic stats alone, some pairs of Pokemon types yield high accuracy with k-NN classification. In other words, if only two Pokemon types are considered, the k-NN classification algorithm can classify them accurately based on their basic stats. This makes sense, as some Pokemon types are known to be stronger in certain stats.

To expound on this, consider the three pairs of types used in the k-NN classification in 4.2.

  1. Electric vs Ground [$k=11$, Accuracy: 82%]: Ground Pokemon are immune to Electric-type attacks.
  2. Fighting vs Steel [$k=9$, Accuracy: 89%]: Fighting-type attacks are super-effective (×2 damage) against Steel Pokemon.
  3. Psychic vs Fighting [$k=25$, Accuracy: 86%]: Psychic Pokemon have high special attack but low attack, whereas Fighting Pokemon have high attack but low special attack.
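
The type matchups cited for these pairs can be encoded as a small lookup table. The sketch below is partial: the names EFFECTIVENESS and multiplier are hypothetical, only the six matchups among the three studied pairs are listed (values from the standard game type chart), and every other matchup is treated as neutral, which the real chart does not do.

```python
# Partial type chart: (attacking type, defending type) -> damage multiplier.
# Only the matchups between the three pairs studied in 4.2 are listed.
EFFECTIVENESS = {
    ('Electric', 'Ground'): 0.0,   # Ground Pokemon are immune to Electric moves
    ('Ground', 'Electric'): 2.0,
    ('Fighting', 'Steel'): 2.0,    # super-effective
    ('Steel', 'Fighting'): 1.0,
    ('Psychic', 'Fighting'): 2.0,  # super-effective
    ('Fighting', 'Psychic'): 0.5,  # not very effective
}


def multiplier(attacking, defending):
    """Damage multiplier of an attacking type against one defending type.
    Simplification: matchups missing from the table are treated as 1x."""
    return EFFECTIVENESS.get((attacking, defending), 1.0)


for a, b in [('Electric', 'Ground'), ('Fighting', 'Steel'),
             ('Psychic', 'Fighting')]:
    print(f"{a} -> {b}: x{multiplier(a, b)}   "
          f"{b} -> {a}: x{multiplier(b, a)}")
```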

Note that high accuracy is not achieved for every pair of Pokemon types. Some pairs that appear strongly contrasted in the show are not classified accurately by the k-NN classification algorithm. This may suggest that, despite being contrasted in the show, these types do not differ much in their basic stats.

5.3 Legendaries can be classified based on their basic stats.

In classifying legendary and non-legendary Pokemon, we implemented the k-NN classification algorithm twice, with two different setups:

  1. We considered ALL 800 Pokemon, of which only 65 are legendary.
  2. We took all 65 legendary Pokemon, randomly selected 65 of the 735 non-legendary Pokemon, and combined them into a balanced data frame of 130 Pokemon.
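
Setup 2 is an instance of random undersampling: the larger class is randomly reduced to the size of the smaller one. A minimal sketch of a general version follows, assuming a hypothetical helper undersample_majority and a toy frame rather than the Pokemon data. For two classes it selects the same kind of set as sampling 65 non-legendary rows and concatenating them with the 65 legendary ones.

```python
import pandas as pd


def undersample_majority(df, label_col, random_state=0):
    """Return a class-balanced copy of df: for every class, randomly keep
    as many rows as the smallest class has (random undersampling)."""
    n_min = df[label_col].value_counts().min()
    # The smallest class is kept whole (sampling n_min of n_min rows)
    parts = [group.sample(n=n_min, random_state=random_state)
             for _, group in df.groupby(label_col)]
    return pd.concat(parts)


# Hypothetical toy frame: 10 non-legendary rows and 3 legendary rows
toy = pd.DataFrame({
    'Total': [300, 320, 310, 405, 290, 330, 350, 280, 315, 325, 680, 600, 670],
    'Legendary': [False] * 10 + [True] * 3,
})
balanced = undersample_majority(toy, 'Legendary')
print(balanced['Legendary'].value_counts())
```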

It turns out that the accuracy is relatively high in both setups:

  • Setup 1: $k=27$, Accuracy: 94%
  • Setup 2: $k=35$, Accuracy: 90%

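This result suggests that the basic stats of legendary Pokemon are distinguishable from those of non-legendary Pokemon. The balanced Setup 2 is the stronger evidence: in Setup 1, always predicting "non-legendary" would already be correct for 735/800 ≈ 92% of Pokemon, so 94% is only a small improvement over that baseline, whereas in Setup 2 the baseline is 50%. These results are consistent with legendary Pokemon having generally higher basic stats than non-legendary Pokemon, although classification accuracy alone does not show the direction of the difference.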
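
Classification accuracy shows that the two groups can be separated, but not which one has higher stats. A more direct check of that direction is to compare the average basic stats per group. The sketch below assumes a hypothetical helper mean_stats_by_group and hand-made toy rows; on the real data it would be called as mean_stats_by_group(df).

```python
import pandas as pd

BASIC_STATS = ['HP', 'Attack', 'Defense', 'Sp. Atk', 'Sp. Def', 'Speed']


def mean_stats_by_group(df, group_col='Legendary'):
    """Average of each basic stat within each group of group_col."""
    return df.groupby(group_col)[BASIC_STATS].mean()


# Hypothetical toy rows (not taken from the data set)
toy = pd.DataFrame({
    'HP':        [40, 60, 100, 110],
    'Attack':    [50, 70, 120, 130],
    'Defense':   [45, 55, 100, 120],
    'Sp. Atk':   [40, 60, 130, 110],
    'Sp. Def':   [50, 50, 100, 120],
    'Speed':     [45, 65, 100, 110],
    'Legendary': [False, False, True, True],
})

means = mean_stats_by_group(toy)
# Groups are sorted False, True: row 1 minus row 0 = legendary minus non-legendary
difference = means.iloc[1] - means.iloc[0]
print(means)
print(difference)
```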
