{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "view-in-github",
    "jupyter": {
     "source_hidden": true
    },
    "tags": []
   },
   "source": [
    "<a href=\"https://colab.research.google.com/drive/1_-DJiIhyYnmAaIO1y2j7YON3er97GDio\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gIT7vmg-1lnu"
   },
   "source": [
    "# Unsupervised Learning\n",
    "\n",
    "In many applications, observations need to be divided into similar groups based on observed features. For example, retailers may want to divide potential customers into groups, in order to target a marketing campaign at the customers who are most likely to respond positively. This practice is known as _market segmentation_.\n",
    "\n",
    "The general problem of grouping observations based on observed features is known as _clustering_ in machine learning. Like the classification problems of Chapter 6, clustering is about dividing observations into categories based on features. Unlike classification, we do not have ground truth labels that specify what the categories should be; they have to be inferred from the data. In other words, with classification, the training data contains both features $X$ and labels $y$; with clustering, the training data only contains features $X$.\n",
    "\n",
    "For this reason, clustering is an example of an _unsupervised learning_ problem, in contrast to the _supervised learning_ problems of the previous chapters. This terminology comes from the following analogy to human learning. Imagine a child who is trying to learn the difference between shapes and has several examples of each shape in front of him.\n",
    "\n",
    "![](https://github.com/dlsun/pods/blob/master/07-Unsupervised-Learning/shape_sorter.jpg?raw=1)\n",
    "\n",
    "On the one hand, the child may be _supervised_ by an adult who gives the child feedback on each answer: \"Yes, that is a circle....No, that was a square....No, that was actually a circle....\"  This process is analogous to classification, where the labels in the training data can be used to provide \"feedback\" on how well the model is doing. Regression and classification are both examples of _supervised learning_ because labels are available in the training data.\n",
    "\n",
    "On the other hand, the child may be _unsupervised_ and completely left to his own devices. Eventually, he may figure out that there is something similar about all of the circles that distinguishes them from the squares. But he won't know that they are called \"circles\", nor will he know whether he is right or not. This is the fundamental challenge of unsupervised learning. Clustering is an example of _unsupervised learning_ because labels are not available in the training data.\n",
    "\n",
    "We will practice clustering on a dataset containing measurements of 150 iris flowers, collected by the botanist Edgar Anderson and made famous by the statistician R. A. Fisher."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 206
    },
    "id": "cQ24UxaQ1lny",
    "outputId": "d0b28b75-d026-4b59-e4ad-8db0b211ffd4"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>SepalLength</th>\n",
       "      <th>SepalWidth</th>\n",
       "      <th>PetalLength</th>\n",
       "      <th>PetalWidth</th>\n",
       "      <th>Name</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>5.1</td>\n",
       "      <td>3.5</td>\n",
       "      <td>1.4</td>\n",
       "      <td>0.2</td>\n",
       "      <td>Iris-setosa</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>4.9</td>\n",
       "      <td>3.0</td>\n",
       "      <td>1.4</td>\n",
       "      <td>0.2</td>\n",
       "      <td>Iris-setosa</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>4.7</td>\n",
       "      <td>3.2</td>\n",
       "      <td>1.3</td>\n",
       "      <td>0.2</td>\n",
       "      <td>Iris-setosa</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4.6</td>\n",
       "      <td>3.1</td>\n",
       "      <td>1.5</td>\n",
       "      <td>0.2</td>\n",
       "      <td>Iris-setosa</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5.0</td>\n",
       "      <td>3.6</td>\n",
       "      <td>1.4</td>\n",
       "      <td>0.2</td>\n",
       "      <td>Iris-setosa</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   SepalLength  SepalWidth  PetalLength  PetalWidth         Name\n",
       "0          5.1         3.5          1.4         0.2  Iris-setosa\n",
       "1          4.9         3.0          1.4         0.2  Iris-setosa\n",
       "2          4.7         3.2          1.3         0.2  Iris-setosa\n",
       "3          4.6         3.1          1.5         0.2  Iris-setosa\n",
       "4          5.0         3.6          1.4         0.2  Iris-setosa"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "data_dir = \"https://dlsun.github.io/pods/data/\"\n",
    "df_iris = pd.read_csv(data_dir + \"iris.csv\")\n",
    "df_iris.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 209
    },
    "id": "xEVf7Frrq02Q",
    "outputId": "7c695181-b458-4bec-b59d-bb7d13747c9b"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Iris-setosa        50\n",
       "Iris-versicolor    50\n",
       "Iris-virginica     50\n",
       "Name: Name, dtype: int64"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_iris.Name.value_counts()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-RrIp1_e1ln4"
   },
   "source": [
    "Let's focus on just two of the variables, the petal length and width, so that we can easily visualize the data. Based on the scatterplot below, how many clusters are there in this data set? Can you devise an algorithm that would automatically identify those clusters?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 466
    },
    "id": "nwnlgfeh1ln5",
    "outputId": "17ccd44b-3699-4589-e5df-44390e809dd0"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<AxesSubplot:xlabel='PetalLength', ylabel='PetalWidth'>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 864x504 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "X_train = df_iris[[\"PetalLength\", \"PetalWidth\"]]\n",
    "X_train.plot.scatter(x=\"PetalLength\", y=\"PetalWidth\",\n",
    "                     c=\"black\", marker=\"x\", alpha=.5, figsize = (12,7))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "8xnkQALwq02S"
   },
   "source": [
    "The iris data set is one of the rare clustering data sets that comes with ground truth: we know each flower's species. Ideally, the clusters would divide the data into the three iris species. The plot below shows what that looks like, with each flower colored by its species.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 522
    },
    "id": "8kHaVDOzq02S",
    "outputId": "12737139-68a6-46c1-c5d6-7cb5a212a7f4"
   },
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 504x432 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "colors = np.array(df_iris['Name'].map({\"Iris-setosa\": \"blue\",\n",
    "                              \"Iris-virginica\":\"green\",\n",
    "                              \"Iris-versicolor\":\"red\"}))\n",
    "\n",
    "f1 = plt.figure(figsize=(7,6))\n",
    "\n",
    "plt.scatter(x = X_train[\"PetalLength\"],\n",
    "            y= X_train[\"PetalWidth\"],\n",
    "            marker=\"x\", alpha=.9, c = colors,\n",
    "            s = 80)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ovyRxHuO1ln-"
   },
   "source": [
    "# $K$-Means Clustering"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "8wjeNsD21loB"
   },
   "source": [
    "$K$-means is an algorithm for finding clusters in data. The idea behind $k$-means is simple: each cluster has a \"center\" point called the **centroid**, and each observation is associated with the cluster of its nearest centroid. The challenge is finding those centroids. The $k$-means algorithm starts with a random guess for the centroids and iteratively improves them.\n",
    "\n",
    "The steps are as follows:\n",
    "\n",
    "1. Initialize $k$ centroids at random.\n",
    "2. Assign each point to the cluster of its nearest centroid.\n",
    "3. (After reassignment, each centroid may no longer be at the center of its cluster.) Recompute each centroid based on the points assigned to its cluster.\n",
    "4. Repeat steps 2 and 3 until no points change clusters.\n",
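    "\n",
    "These four steps can be written as a compact NumPy function. The sketch below is one minimal way to do it, assuming the data is a NumPy array `X` of shape `(n, p)`; the function name `kmeans_sketch`, its arguments, and the synthetic data are hypothetical, chosen only for illustration. Step 1 picks $k$ distinct observations as the starting centroids, much like the sampling we do with the iris data below.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def kmeans_sketch(X, k, max_iter=100, seed=0):\n",
    "    # Returns (labels, centroids) for a k-means clustering of the rows of X.\n",
    "    rng = np.random.default_rng(seed)\n",
    "    # Step 1: use k distinct observations, chosen at random, as initial centroids.\n",
    "    centroids = X[rng.choice(len(X), size=k, replace=False)].astype(float)\n",
    "    labels = None\n",
    "    for _ in range(max_iter):\n",
    "        # Step 2: squared distance from every point to every centroid, shape (n, k).\n",
    "        d2 = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)\n",
    "        new_labels = d2.argmin(axis=1)\n",
    "        # Step 4: stop once no point changes cluster.\n",
    "        if labels is not None and np.array_equal(new_labels, labels):\n",
    "            break\n",
    "        labels = new_labels\n",
    "        # Step 3: move each centroid to the mean of its points.\n",
    "        # (In this sketch, a cluster with no points keeps its old centroid.)\n",
    "        for j in range(k):\n",
    "            if np.any(labels == j):\n",
    "                centroids[j] = X[labels == j].mean(axis=0)\n",
    "    return labels, centroids\n",
    "\n",
    "# Usage on three synthetic blobs of 20 points each.\n",
    "rng = np.random.default_rng(1)\n",
    "X = np.vstack([rng.normal(center, 0.3, size=(20, 2))\n",
    "               for center in ([0, 0], [3, 3], [0, 3])])\n",
    "labels, centroids = kmeans_sketch(X, k=3)\n",
    "print(labels)\n",
    "print(centroids)\n",
    "```\n",
    "\n",
    "Like any $k$-means run from a single random start, the result can depend on `seed`.\n",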
    "\n",
    "# Implementing K-Means from Scratch\n",
    "\n",
    "We will start by implementing the $k$-means algorithm from scratch. To begin, let's sample 3 points at random from the iris data to serve as the initial centroids."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 575
    },
    "id": "-IhlAYDk1loC",
    "outputId": "50f3cd41-49f3-4b9d-b24a-fe41b4dec76e"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>PetalLength</th>\n",
       "      <th>PetalWidth</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>r</th>\n",
       "      <td>6.1</td>\n",
       "      <td>2.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>b</th>\n",
       "      <td>4.9</td>\n",
       "      <td>1.8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>g</th>\n",
       "      <td>4.0</td>\n",
       "      <td>1.2</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   PetalLength  PetalWidth\n",
       "r          6.1         2.3\n",
       "b          4.9         1.8\n",
       "g          4.0         1.2"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Initialize 3 centroids at random from the data.\n",
    "centroids = X_train.sample(3)\n",
    "\n",
    "# Call the three clusters \"red\", \"blue\", \"green\" for convenience.\n",
    "centroids.index = [\"r\", \"b\", \"g\"]\n",
    "\n",
    "# Plot these centroids.\n",
    "ax = X_train.plot.scatter(x=\"PetalLength\", y=\"PetalWidth\",\n",
    "                          c=\"black\", marker=\"x\", alpha=.5)\n",
    "centroids.plot.scatter(x=\"PetalLength\", y=\"PetalWidth\",\n",
    "                       c=centroids.index, s = 100, ax=ax)\n",
    "\n",
    "centroids"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "eUMVbJrS1loI"
   },
   "source": [
    "Now we assign each point to the cluster of its nearest centroid.\n",
    "\n",
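    "To see how this works, let's first compute the (squared) distances from a single observation, row 90, to each of the three centroids. Then we will apply the same idea to every observation.\n"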
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 147
    },
    "id": "1gsOzjvwq02T",
    "outputId": "22db0748-fd42-4d20-95fe-b78acc1afe9d"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PetalLength    4.4\n",
       "PetalWidth     1.2\n",
       "Name: 90, dtype: float64"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tmp = X_train.iloc[90]\n",
    "\n",
    "tmp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 178
    },
    "id": "EZSjAunXq02T",
    "outputId": "0fea4cad-c7e1-4ef4-8024-544f5e69d6ab"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "r    4.10\n",
       "b    0.61\n",
       "g    0.16\n",
       "dtype: float64"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Squared Euclidean distance from the selected observation to each centroid.\n",
    "((tmp - centroids)**2).sum(axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 522
    },
    "id": "ZQegd5DUq02U",
    "outputId": "76b4ebc8-8deb-4143-fbcf-9556fa9040a0"
   },
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 648x432 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "f2 = plt.figure(figsize = (9,6))\n",
    "\n",
    "### plot all points\n",
    "plt.scatter(x= X_train[\"PetalLength\"],\n",
    "            y= X_train[\"PetalWidth\"],\n",
    "            c=\"black\", marker=\"x\", s= 80, alpha=.5)\n",
    "\n",
    "### plot colored centroids\n",
    "plt.scatter(x=centroids[\"PetalLength\"],\n",
    "            y=centroids[\"PetalWidth\"],\n",
    "            c=centroids.index, s = 100)\n",
    "\n",
    "### plot the point we selected\n",
    "\n",
    "plt.scatter(x = tmp[\"PetalLength\"],\n",
    "            y = tmp[\"PetalWidth\"],\n",
    "            c = \"cyan\", marker = \"s\", s=200 )\n",
    "\n",
    "## plot distances to centroids\n",
    "\n",
    "for clr, p in centroids.iterrows():\n",
    "    # Dashed line from the selected point to the centroid labeled clr.\n",
    "    x = np.array([tmp[\"PetalLength\"], p[\"PetalLength\"]])\n",
    "    y = np.array([tmp[\"PetalWidth\"], p[\"PetalWidth\"]])\n",
    "    plt.plot(x, y, color=clr,\n",
    "             linestyle=\"--\",\n",
    "             linewidth=3)\n",
    "\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 36
    },
    "id": "J4MtsHos1loJ",
    "outputId": "7e75bcca-876b-4649-eba2-14c3a130b577"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'g'"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Finds the nearest centroid to a given observation.\n",
    "def get_nearest_centroid(obs):\n",
    "    dists = np.sqrt(((obs - centroids) ** 2).sum(axis=1))\n",
    "    return dists.idxmin()\n",
    "\n",
    "get_nearest_centroid(X_train.loc[90])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 466
    },
    "id": "33vamCDa1loN",
    "outputId": "d6eaf027-a345-4f90-9fbe-540eb91d41da"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<AxesSubplot:xlabel='PetalLength', ylabel='PetalWidth'>"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Apply the function to the entire data set.\n",
    "clusters = X_train.apply(get_nearest_centroid, axis=1)\n",
    "\n",
    "# Plot the cluster assignments.\n",
    "ax = X_train.plot.scatter(x=\"PetalLength\", y=\"PetalWidth\",\n",
    "                          c=clusters, marker=\"x\", alpha=.5)\n",
    "centroids.plot.scatter(x=\"PetalLength\", y=\"PetalWidth\",\n",
    "                       c=centroids.index, s = 100, ax=ax)"
   ]
  },
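  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`X_train.apply(get_nearest_centroid, axis=1)` calls a Python function once per row, which is fine for 150 flowers but can be slow on large data sets. A common alternative is to compute every point-to-centroid distance at once with NumPy broadcasting. The sketch below assumes `X` and `centroids` are plain NumPy arrays (for example, `X_train.to_numpy()` and `centroids.to_numpy()`); the function name `assign_clusters` is hypothetical. The example centroids are the three displayed earlier, in the order r, b, g.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def assign_clusters(X, centroids):\n",
    "    # X has shape (n, p) and centroids has shape (k, p).\n",
    "    # Broadcasting (n, 1, p) against (1, k, p) gives all n * k differences at once.\n",
    "    d2 = ((X[:, np.newaxis, :] - centroids[np.newaxis, :, :]) ** 2).sum(axis=2)\n",
    "    # d2[i, j] is the squared distance from point i to centroid j.\n",
    "    return d2.argmin(axis=1)\n",
    "\n",
    "X = np.array([[4.4, 1.2], [1.4, 0.2], [6.0, 2.5]])\n",
    "centroids = np.array([[6.1, 2.3], [4.9, 1.8], [4.0, 1.2]])\n",
    "print(assign_clusters(X, centroids))  # [2 2 0]\n",
    "```\n",
    "\n",
    "The first point is observation 90 from the worked example, and index 2 is the third centroid, `g`, matching `get_nearest_centroid`. Squared distances are enough here because taking a square root does not change which centroid is closest."
   ]
  },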
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7GTX3HBb1loS"
   },
   "source": [
    "Notice that some of the centroids are not at the center of their clusters. We can fix that by redefining the centroid to be the mean of the points in its cluster."
   ]
  },
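  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Why the mean? For a cluster containing points $x_1, \\dots, x_n$, the centroid $c$ that minimizes the total squared distance to those points is found by setting the gradient to zero:\n",
    "\n",
    "$$\n",
    "\\nabla_c \\sum_{i=1}^{n} \\lVert x_i - c \\rVert^2 = -2 \\sum_{i=1}^{n} (x_i - c) = 0\n",
    "\\quad\\Longrightarrow\\quad\n",
    "c = \\frac{1}{n} \\sum_{i=1}^{n} x_i .\n",
    "$$\n",
    "\n",
    "Because the sum of squares is a convex function of $c$, this stationary point is the minimum, so moving each centroid to the mean of its cluster can only reduce (or leave unchanged) the sum of squared distances within that cluster."
   ]
  },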
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 143
    },
    "id": "nub9Ke5Yq02V",
    "outputId": "e087dce0-1239-4d7e-dec3-f429ebd1d65c"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>PetalLength</th>\n",
       "      <th>PetalWidth</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>b</th>\n",
       "      <td>4.931111</td>\n",
       "      <td>1.7200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>g</th>\n",
       "      <td>2.405000</td>\n",
       "      <td>0.6075</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>r</th>\n",
       "      <td>5.980000</td>\n",
       "      <td>2.1520</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   PetalLength  PetalWidth\n",
       "b     4.931111      1.7200\n",
       "g     2.405000      0.6075\n",
       "r     5.980000      2.1520"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_train.groupby(clusters).mean()"
   ]
  },
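  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One caveat: `X_train.groupby(clusters).mean()` only returns rows for labels that actually occur in `clusters`. If some centroid ends up with no points assigned to it, that cluster silently disappears and we are left with fewer than $k$ centroids. Implementations deal with this in different ways. The sketch below takes one simple approach, keeping the previous position of any empty cluster's centroid; the function name `update_centroids` and the toy data are hypothetical.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "def update_centroids(X, clusters, old_centroids):\n",
    "    # Mean of each non-empty cluster; empty clusters are absent from this result.\n",
    "    means = X.groupby(clusters).mean()\n",
    "    # Restore every original label (empty clusters become rows of NaN),\n",
    "    # then fill those rows with the previous centroid positions.\n",
    "    return means.reindex(old_centroids.index).fillna(old_centroids)\n",
    "\n",
    "X = pd.DataFrame({'PetalLength': [1.4, 1.5, 5.0], 'PetalWidth': [0.2, 0.3, 1.8]})\n",
    "old = pd.DataFrame({'PetalLength': [1.0, 5.5, 7.0], 'PetalWidth': [0.1, 2.0, 3.0]},\n",
    "                   index=['r', 'b', 'g'])\n",
    "clusters = pd.Series(['r', 'r', 'b'])  # no point is assigned to 'g'\n",
    "print(update_centroids(X, clusters, old))\n",
    "```"
   ]
  },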
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 575
    },
    "id": "ZZ8zXdpA1loT",
    "outputId": "8beb9ed9-6f27-4959-ecb9-b8a83bb70193"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>PetalLength</th>\n",
       "      <th>PetalWidth</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>b</th>\n",
       "      <td>4.931111</td>\n",
       "      <td>1.7200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>g</th>\n",
       "      <td>2.405000</td>\n",
       "      <td>0.6075</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>r</th>\n",
       "      <td>5.980000</td>\n",
       "      <td>2.1520</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   PetalLength  PetalWidth\n",
       "b     4.931111      1.7200\n",
       "g     2.405000      0.6075\n",
       "r     5.980000      2.1520"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Calculate the mean length and width for each cluster.\n",
    "\n",
    "old_centroids = centroids.copy()\n",
    "\n",
    "centroids = X_train.groupby(clusters).mean()\n",
    "\n",
    "# Let's plot the new centroids.\n",
    "ax = X_train.plot.scatter(x=\"PetalLength\", y=\"PetalWidth\",\n",
    "                          c=clusters, marker=\"x\", alpha=.5)\n",
    "old_centroids.plot.scatter(x=\"PetalLength\", y = \"PetalWidth\",\n",
    "                           c = old_centroids.index, s =100, marker = \"s\", ax=ax)\n",
    "centroids.plot.scatter(x=\"PetalLength\", y=\"PetalWidth\",\n",
    "                       c=centroids.index, s = 100, ax=ax)\n",
    "\n",
    "centroids"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "53wheguP1loW"
   },
   "source": [
    "Now, there may be some points that are no longer assigned to their closest centroid, so we have to go back and re-assign clusters. But that may cause the centroids to no longer be at the center of their cluster, so we have to recalculate the centroids. And so on. This process continues until the cluster assignments stop changing."
   ]
  },
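  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Why does this back-and-forth eventually stop? Both steps can only decrease (or leave unchanged) the _within-cluster sum of squares_, $\\sum_{j} \\sum_{i \\in C_j} \\lVert x_i - \\mu_j \\rVert^2$, where $C_j$ is the set of points in cluster $j$ and $\\mu_j$ is its centroid. Reassigning a point to its nearest centroid cannot increase that point's term, and moving each centroid to the mean of its cluster cannot increase the total, because the mean minimizes the sum of squared distances to a set of points. Since there are only finitely many ways to assign points to clusters, the assignments eventually stop changing, though possibly at a local rather than a global minimum. _scikit-learn_ reports this quantity as the `inertia_` attribute of a fitted `KMeans` model. A minimal sketch of computing it directly, assuming NumPy arrays and integer cluster labels (the function name `within_cluster_ss` is hypothetical):\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def within_cluster_ss(X, labels, centroids):\n",
    "    # centroids[labels] lines up each point with the centroid of its own cluster.\n",
    "    return float(((X - centroids[labels]) ** 2).sum())\n",
    "\n",
    "X = np.array([[1.0, 0.0], [3.0, 0.0], [10.0, 1.0], [10.0, 3.0]])\n",
    "labels = np.array([0, 0, 1, 1])\n",
    "centroids = np.array([[2.0, 0.0], [10.0, 2.0]])\n",
    "print(within_cluster_ss(X, labels, centroids))  # 4.0\n",
    "```"
   ]
  },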
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 898
    },
    "id": "-6pANF0j1loY",
    "outputId": "c360a1f8-1ee5-46df-db4e-f6603e4fb290"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<AxesSubplot:xlabel='PetalLength', ylabel='PetalWidth'>"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "old_clusters = clusters.copy()\n",
    "\n",
    "# Assign points to their nearest centroid.\n",
    "clusters = X_train.apply(get_nearest_centroid, axis=1)\n",
    "\n",
    "# Recalculate the centroids based on the clusters.\n",
    "centroids = X_train.groupby(clusters).mean()\n",
    "\n",
    "# Plot the current cluster assignments and the centroids.\n",
    "ax = X_train.plot.scatter(x=\"PetalLength\", y=\"PetalWidth\",\n",
    "                          c=clusters, marker=\"x\", alpha=.5)\n",
    "centroids.plot.scatter(x=\"PetalLength\", y=\"PetalWidth\",\n",
    "                       c=centroids.index, s=100, ax=ax)\n",
    "\n",
    "# For comparison, plot the previous cluster assignments with the updated centroids.\n",
    "ax2 = X_train.plot.scatter(x=\"PetalLength\", y=\"PetalWidth\",\n",
    "                           c=old_clusters, marker=\"x\", alpha=.5)\n",
    "centroids.plot.scatter(x=\"PetalLength\", y=\"PetalWidth\",\n",
    "                       c=centroids.index, s=100, ax=ax2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tsC4v8ZR1lob"
   },
   "source": [
    "We can run the code in the above cell over and over until the clusters stop changing. At that point, we have the final cluster assignment.\n",
    "\n",
    "It is not so easy to visualize the clusters when there are more than 2 features. But we can wrap the same algorithm inside a loop that continues until the cluster assignments do not change from one step to the next. One of the exercises below walks you through such an implementation.\n",
    "\n",
    "# K-Means in _scikit-learn_\n",
    "\n",
    "We rarely need to implement the $k$-means algorithm from scratch because it is available in _scikit-learn_. The API for _scikit-learn_'s $k$-means model is similar to the API for supervised learning models, like $k$-nearest neighbors, except that the `.fit()` method only takes in `X`, not `X` and `y`. This makes sense because in unsupervised learning, there are no ground truth labels `y`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 74
    },
    "id": "S5zTk9MI1lod",
    "outputId": "dfb72f62-47a4-476c-ae56-64c10f864824"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "KMeans(init='random', n_clusters=3, n_init=1, random_state=100)"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.cluster import KMeans\n",
    "\n",
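    "# init='random' picks the starting centroids from the data at random, like X_train.sample(3) above.\n",
    "model = KMeans(n_clusters=3, init=\"random\", n_init=1, random_state=100)\n",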
    "model.fit(X_train)"
   ]
  },
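  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because $k$-means starts from random centroids, different starting points can lead to different final clusterings, some worse than others. The model above makes a single run (`n_init=1`). The sketch below compares the `inertia_` (the within-cluster sum of squared distances) of a few single random starts with a model that makes 10 starts and keeps the one with the lowest inertia. To stay self-contained it loads the iris measurements with `sklearn.datasets.load_iris` instead of the CSV file used above; the seeds are arbitrary.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from sklearn.cluster import KMeans\n",
    "from sklearn.datasets import load_iris\n",
    "\n",
    "# Columns 2 and 3 of the iris data are petal length and petal width.\n",
    "X = load_iris().data[:, 2:4]\n",
    "\n",
    "# Single random starts: the final inertia can differ from seed to seed.\n",
    "single = [KMeans(n_clusters=3, init='random', n_init=1, random_state=s).fit(X).inertia_\n",
    "          for s in range(5)]\n",
    "\n",
    "# Ten random starts; the fitted model keeps the run with the lowest inertia.\n",
    "best = KMeans(n_clusters=3, init='random', n_init=10, random_state=0).fit(X).inertia_\n",
    "\n",
    "print(np.round(single, 2))\n",
    "print(round(best, 2))\n",
    "```\n",
    "\n",
    "On well-separated data like the petal measurements, all runs may well land on the same answer; the benefit of a larger `n_init` shows up on harder data sets, at the cost of extra computation."
   ]
  },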
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "Xywvpg_y1lol",
    "outputId": "420cd7d7-824f-4057-eef5-9282c907690a"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([[4.26923077, 1.34230769],\n",
       "        [1.464     , 0.244     ],\n",
       "        [5.59583333, 2.0375    ]]),\n",
       " array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "        1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0,\n",
       "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2,\n",
       "        2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2,\n",
       "        2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=int32))"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Extract the centroids and the clusters.\n",
    "centroids = model.cluster_centers_\n",
    "clusters = model.labels_\n",
    "\n",
    "centroids, clusters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 466
    },
    "id": "qtWZ0V7v1loq",
    "outputId": "615dd92a-333b-4782-88e8-b8c77fe61cb9"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Axes: xlabel='PetalLength', ylabel='PetalWidth'>"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Map the cluster numbers to colors.\n",
    "clusters = pd.Series(clusters).map({\n",
    "    0: \"r\",\n",
    "    1: \"b\",\n",
    "    2: \"g\"\n",
    "})\n",
    "\n",
    "# Plot the data\n",
    "X_train.plot.scatter(x=\"PetalLength\", y=\"PetalWidth\",\n",
    "                     c=clusters, marker=\"x\", alpha=.5)"
   ]
  },
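  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since the iris data comes with species labels, we can check how well the clusters line up with the ground truth. Cluster numbers are arbitrary (cluster 0 need not correspond to any particular species), so a cross-tabulation of species against cluster is more informative than comparing the labels directly. The adjusted Rand index from `sklearn.metrics` summarizes the agreement in one number that does not depend on how the clusters are numbered: 1 means the two groupings are identical, and values near 0 mean agreement no better than chance. This sketch refits a model on `load_iris()` so that it runs on its own, so its cluster numbering may differ from the model above.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "from sklearn.cluster import KMeans\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.metrics import adjusted_rand_score\n",
    "\n",
    "iris = load_iris()\n",
    "X = iris.data[:, 2:4]  # petal length and petal width\n",
    "species = pd.Series(iris.target_names[iris.target], name='species')\n",
    "\n",
    "labels = KMeans(n_clusters=3, n_init=10, random_state=0).fit_predict(X)\n",
    "\n",
    "# Rows are the true species; columns are the cluster numbers from k-means.\n",
    "table = pd.crosstab(species, pd.Series(labels, name='cluster'))\n",
    "print(table)\n",
    "\n",
    "# The adjusted Rand index ignores how the clusters happen to be numbered.\n",
    "ari = adjusted_rand_score(iris.target, labels)\n",
    "print(round(ari, 3))\n",
    "```"
   ]
  },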
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FiaAIkib1lou"
   },
   "source": [
    "We can call `model.predict()` to get the cluster assignment for a new observation. This simply assigns the new observation to the cluster with the nearest centroid, without recalculating the centroids. (Had this observation been part of the training data, it would have affected the centroids themselves: adding it to a cluster would shift that cluster's centroid, which could in turn cause other observations to be reassigned, and so on.)\n",
    "\n",
    "For example, consider a flower whose petal has a length of 5.0 and a width of 0.5. Which cluster would you expect this point to be assigned to? Let's check by calling `.predict()` on our fitted model."
   ]
  },
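  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Under the hood, `.predict()` applies the same rule as Step 2 of the algorithm: each new observation gets the index of the nearest row of `model.cluster_centers_`. The sketch below checks this on a separately fitted model; it loads the data with `load_iris()` so that it is self-contained, and the new points are arbitrary.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from sklearn.cluster import KMeans\n",
    "from sklearn.datasets import load_iris\n",
    "\n",
    "X = load_iris().data[:, 2:4]  # petal length and petal width\n",
    "model = KMeans(n_clusters=3, n_init=10, random_state=0).fit(X)\n",
    "\n",
    "new_points = np.array([[5.0, 0.5], [1.5, 0.3], [6.0, 2.2]])\n",
    "\n",
    "# Squared distance from each new point to each fitted centroid, shape (3, 3).\n",
    "d2 = ((new_points[:, None, :] - model.cluster_centers_[None, :, :]) ** 2).sum(axis=2)\n",
    "manual = d2.argmin(axis=1)\n",
    "\n",
    "print(manual, model.predict(new_points))\n",
    "```"
   ]
  },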
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "7LtOvWDW1lov"
   },
   "outputs": [],
   "source": [
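    "new_flower = pd.DataFrame([[5.0, 0.5]], columns=X_train.columns)  # same feature names as the training data\n",
    "model.predict(new_flower)"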
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}