{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Solving classification problems with CatBoost" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/catboost/tutorials/blob/master/events/pydata_la_oct_21_2018.ipynb)\n", "\n", "In this tutorial we will use the Amazon Employee Access Challenge dataset from a [Kaggle](https://www.kaggle.com) competition for our experiments. The data can be downloaded [here](https://www.kaggle.com/c/amazon-employee-access-challenge/data)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Library installation" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "#!pip install --user --upgrade catboost\n", "#!pip install --user --upgrade ipywidgets\n", "#!pip install shap\n", "#!pip install scikit-learn\n", "#!pip install --upgrade numpy\n", "#!pip install --upgrade pandas\n", "#!jupyter nbextension enable --py widgetsnbextension" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.11.2\n", "Python 2.7.12\r\n" ] } ], "source": [ "import catboost\n", "print(catboost.__version__)\n", "!python --version" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Reading the data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The simplest way is to read everything into a pandas DataFrame." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import os\n", "import numpy as np\n", "np.set_printoptions(precision=4)\n", "import catboost\n", "from catboost import *\n", "from catboost import datasets" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Uncomment if downloading the dataset fails with an SSL certificate error (this disables certificate verification)\n", "# import ssl\n", "# ssl._create_default_https_context = ssl._create_unverified_context" ] }, { 
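"cell_type": "markdown", "metadata": {}, "source": [ "If you have already downloaded `train.csv` from Kaggle, you can also read it directly with `pd.read_csv`. The next cell is a minimal sketch that assumes the file was saved under a hypothetical `./kaggle_amazon` directory; adjust `kaggle_dir` to your own location. It only reads the file if it exists, so the notebook still runs without it." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Hypothetical location of a manually downloaded Kaggle train.csv (an assumption; adjust as needed)\n", "kaggle_dir = './kaggle_amazon'\n", "train_csv_path = os.path.join(kaggle_dir, 'train.csv')\n", "\n", "if os.path.exists(train_csv_path):\n", "    # The header row provides the column names (ACTION, RESOURCE, MGR_ID, ...)\n", "    kaggle_train_df = pd.read_csv(train_csv_path)\n", "    print(kaggle_train_df.shape)\n", "else:\n", "    print('No file found at ' + train_csv_path)" ] }, { 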
"cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "(train_df, test_df) = catboost.datasets.amazon()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ " ACTION RESOURCE MGR_ID ROLE_ROLLUP_1 ROLE_ROLLUP_2 ROLE_DEPTNAME \\\n", "0 1 39353 85475 117961 118300 123472 \n", "1 1 17183 1540 117961 118343 123125 \n", "2 1 36724 14457 118219 118220 117884 \n", "3 1 36135 5396 117961 118343 119993 \n", "4 1 42680 5905 117929 117930 119569 \n", "\n", " ROLE_TITLE ROLE_FAMILY_DESC ROLE_FAMILY ROLE_CODE \n", "0 117905 117906 290919 117908 \n", "1 118536 118536 308574 118539 \n", "2 117879 267952 19721 117880 \n", "3 118321 240983 290919 118322 \n", "4 119323 123932 19793 119325 " ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_df.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Preparing your data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Extracting the label values" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "y = train_df.ACTION\n", "X = train_df.drop('ACTION', axis=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Declaring the categorical features (here every column of `X` is categorical)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0, 1, 2, 3, 4, 5, 6, 7, 8]\n" ] } ], "source": [ "cat_features = list(range(0, X.shape[1]))\n", "print(cat_features)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Looking at the label balance in the dataset" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "('Labels: ', set([0, 1]))\n", "Zero count = 1897, One count = 30872\n" ] } ], "source": [ "print('Labels: ', set(y))\n", "print('Zero count = ' + str(len(y) - sum(y)) + ', One count = ' + str(sum(y)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To train a model in CatBoost, we need to wrap the data in a `Pool` object.\n", "This class stores the data in CatBoost's internal format.\n", "\n", "There are several ways 
to create a `Pool`.\n", "The simplest one is to create it from a pandas DataFrame or a NumPy array:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "pool1 = Pool(data=X,\n", "             label=y,\n", "             cat_features=cat_features)  # Indices of categorical columns in X" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This approach is not the most efficient, especially for big datasets: everything has to be copied from pandas into CatBoost's internal format.\n", "\n", "So CatBoost can also create a `Pool` directly from a file." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's look at how to load a `Pool` from a file." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, let's save the DataFrame to disk." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "dataset_dir = './amazon'\n", "if not os.path.exists(dataset_dir):\n", "    os.makedirs(dataset_dir)\n", "\n", "# Save two variants to show that CatBoost can read files with or without a header and with different separators.\n", "train_df.to_csv(os.path.join(dataset_dir, 'train.tsv'), index=False, sep='\\t', header=False)\n", "test_df.to_csv(os.path.join(dataset_dir, 'test.tsv'), index=False, sep='\\t', header=False)\n", "\n", "train_df.to_csv(os.path.join(dataset_dir, 'train.csv'), index=False, sep=',', header=True)\n", "test_df.to_csv(os.path.join(dataset_dir, 'test.csv'), index=False, sep=',', header=True)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1\t39353\t85475\t117961\t118300\t123472\t117905\t117906\t290919\t117908\r\n", "1\t17183\t1540\t117961\t118343\t123125\t118536\t118536\t308574\t118539\r\n" ] } ], "source": [ "!head -n2 amazon/train.tsv" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ 
"ACTION,RESOURCE,MGR_ID,ROLE_ROLLUP_1,ROLE_ROLLUP_2,ROLE_DEPTNAME,ROLE_TITLE,ROLE_FAMILY_DESC,ROLE_FAMILY,ROLE_CODE\r\n", "1,39353,85475,117961,118300,123472,117905,117906,290919,117908\r\n", "1,17183,1540,117961,118343,123125,118536,118536,308574,118539\r\n" ] } ], "source": [ "!head -n3 amazon/train.csv" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we have the dataset in two different formats:\n", "\n", "1) tab-separated, without a header\n", "\n", "2) comma-separated, with a header\n", "\n", "\n", "Like pandas, CatBoost can load data in different formats; we just need to pass the proper options.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before loading the data, we need to specify the type of each column. CatBoost uses a special file for this, the column description (`.cd`) file.\n", "\n", "The helper function `create_cd` makes it easy to generate one:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "from catboost.utils import create_cd\n", "\n", "# Map column index -> name for every feature column (column 0 is the label)\n", "feature_names = dict()\n", "for column, name in enumerate(train_df):\n", "    if column == 0:\n", "        continue\n", "    feature_names[column] = name\n", "\n", "create_cd(\n", "    label=0,\n", "    cat_features=list(range(1, train_df.columns.shape[0])),\n", "    feature_names=feature_names,\n", "    output_path=os.path.join(dataset_dir, 'train.cd')\n", ")" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0\tLabel\t\r\n", "1\tCateg\tRESOURCE\r\n", "2\tCateg\tMGR_ID\r\n", "3\tCateg\tROLE_ROLLUP_1\r\n", "4\tCateg\tROLE_ROLLUP_2\r\n", "5\tCateg\tROLE_DEPTNAME\r\n", "6\tCateg\tROLE_TITLE\r\n", "7\tCateg\tROLE_FAMILY_DESC\r\n", "8\tCateg\tROLE_FAMILY\r\n", "9\tCateg\tROLE_CODE\r\n" ] } ], "source": [ "!cat amazon/train.cd" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "create_cd(\n", " label=0, \n", "
cat_features=list(range(1, train_df.columns.shape[0])),\n", " # feature_names=feature_names,\n", " output_path=os.path.join(dataset_dir, 'train_without_names.cd')\n", ")" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0\tLabel\t\r\n", "1\tCateg\t\r\n", "2\tCateg\t\r\n", "3\tCateg\t\r\n", "4\tCateg\t\r\n", "5\tCateg\t\r\n", "6\tCateg\t\r\n", "7\tCateg\t\r\n", "8\tCateg\t\r\n", "9\tCateg\t\r\n" ] } ], "source": [ "!cat amazon/train_without_names.cd" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "create_cd(\n", " label=0, \n", " cat_features=list(range(2, train_df.columns.shape[0])),\n", " feature_names=feature_names,\n", " output_path=os.path.join(dataset_dir, 'train_with_num.cd')\n", ")" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0\tLabel\t\r\n", "1\tNum\tRESOURCE\r\n", "2\tCateg\tMGR_ID\r\n", "3\tCateg\tROLE_ROLLUP_1\r\n", "4\tCateg\tROLE_ROLLUP_2\r\n", "5\tCateg\tROLE_DEPTNAME\r\n", "6\tCateg\tROLE_TITLE\r\n", "7\tCateg\tROLE_FAMILY_DESC\r\n", "8\tCateg\tROLE_FAMILY\r\n", "9\tCateg\tROLE_CODE\r\n" ] } ], "source": [ "!cat amazon/train_with_num.cd" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "? 
create_cd" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's load a `Pool` from the file:" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "pool2 = Pool(\n", "    data=os.path.join(dataset_dir, 'train.tsv'),\n", "    # delimiter=',',  # needed for train.csv\n", "    column_description=os.path.join(dataset_dir, 'train.cd'),\n", "    # has_header=True  # needed for train.csv\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Loading a `Pool` from a file is the fastest way to build it if the data is not in RAM yet." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Exercise: load the same pool from the CSV file." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you want maximum performance and your data is already in RAM, in some cases you can do better than simply passing a DataFrame to the `Pool` constructor.\n", "\n", "The `FeaturesData` class is a fast way to pass data from NumPy matrices to CatBoost." ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "? FeaturesData" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "# The fastest way to create a Pool is from a NumPy matrix. 
Use it if you need fast predictions\n", "# or the fastest way to load the data in Python.\n", "\n", "# FeaturesData requires categorical features as str values stored in an object array\n", "X_prepared = X.values.astype(str).astype(object)\n", "\n", "pool3 = Pool(\n", "    data=FeaturesData(cat_feature_data=X_prepared, cat_feature_names=list(X)),\n", "    label=y.values\n", ")\n" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dataset shape\n", "dataset 1:(32769, 9)\n", "dataset 2:(32769, 9)\n", "dataset 3:(32769, 9)\n", "\n", "\n", "Column names\n", "dataset 1:\n", "['RESOURCE', 'MGR_ID', 'ROLE_ROLLUP_1', 'ROLE_ROLLUP_2', 'ROLE_DEPTNAME', 'ROLE_TITLE', 'ROLE_FAMILY_DESC', 'ROLE_FAMILY', 'ROLE_CODE']\n", "\n", "dataset 2:\n", "['RESOURCE', 'MGR_ID', 'ROLE_ROLLUP_1', 'ROLE_ROLLUP_2', 'ROLE_DEPTNAME', 'ROLE_TITLE', 'ROLE_FAMILY_DESC', 'ROLE_FAMILY', 'ROLE_CODE']\n", "\n", "dataset 3:\n", "['RESOURCE', 'MGR_ID', 'ROLE_ROLLUP_1', 'ROLE_ROLLUP_2', 'ROLE_DEPTNAME', 'ROLE_TITLE', 'ROLE_FAMILY_DESC', 'ROLE_FAMILY', 'ROLE_CODE']\n" ] } ], "source": [ "print('Dataset shape')\n", "print('dataset 1:' + str(pool1.shape) + '\\ndataset 2:' + str(pool2.shape) + \n", " '\\ndataset 3:' + str(pool3.shape))\n", "\n", "print('\\n')\n", "print('Column names')\n", "print('dataset 1:')\n", "print(pool1.get_feature_names()) \n", "print('\\ndataset 2:')\n", "print(pool2.get_feature_names())\n", "print('\\ndataset 3:')\n", "print(pool3.get_feature_names())\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Split your data into training and validation sets" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/annaveronika/.local/lib/python2.7/site-packages/sklearn/model_selection/_split.py:2179: FutureWarning: From version 0.21, test_size will always complement train_size unless both are specified.\n", " FutureWarning)\n" ] 
} ], "source": [ "from sklearn.model_selection import train_test_split\n", "X_train, X_validation, y_train, y_validation = train_test_split(X, y, train_size=0.8, random_state=1234)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ " RESOURCE MGR_ID ROLE_ROLLUP_1 ROLE_ROLLUP_2 ROLE_DEPTNAME \\\n", "14926 74463 105908 117961 118225 129617 \n", "7940 17278 120340 120342 120343 119076 \n", "24768 79325 17733 117961 118300 119984 \n", "1633 5173 3075 117961 117962 120677 \n", "2743 15672 3745 117961 118300 118360 \n", "\n", " ROLE_TITLE ROLE_FAMILY_DESC ROLE_FAMILY ROLE_CODE \n", "14926 118702 132654 118704 118705 \n", "7940 118834 311236 118424 118836 \n", "24768 118890 125128 118398 118892 \n", "1633 120357 120678 118424 120359 \n", "2743 124435 118362 118363 124436 " ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_train.head()" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(26215, 9)" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_train.shape" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(6554, 9)" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_validation.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Selecting the objective function" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Possible options for binary classification:\n", "\n", "`Logloss`\n", "\n", "`CrossEntropy` (for targets given as probabilities)" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from catboost import CatBoostClassifier\n", "\n", "model = CatBoostClassifier(\n", " iterations=5,\n", " learning_rate=0.1,\n", " #loss_function='Logloss',\n", " #loss_function='CrossEntropy'\n", ")\n", "\n", "train_pool = Pool(data=X_train, \n", " label=y_train, \n", " cat_features=cat_features)\n", "\n", "validation_pool = Pool(data=X_validation, \n", "
label=y_validation, \n", " cat_features=cat_features)\n", "model.fit(\n", " train_pool,\n", " eval_set=validation_pool,\n", " verbose=False\n", ")" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model is fitted: True\n", "Model params:\n", "{'learning_rate': 0.1, 'loss_function': 'Logloss', 'iterations': 5}\n" ] } ], "source": [ "print('Model is fitted: ' + str(model.is_fitted()))\n", "print('Model params:')\n", "print(model.get_params())" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = CatBoostClassifier(\n", " iterations=5,\n", " learning_rate=0.1,\n", " #loss_function='Logloss',\n", " #loss_function='CrossEntropy'\n", ")\n", "\n", "model.fit(\n", " X_train, y_train,\n", " cat_features=cat_features,\n", " eval_set=(X_validation, y_validation),\n", " verbose=False\n", ")" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model is fitted: True\n", "Model params:\n", "{'learning_rate': 0.1, 'loss_function': 'Logloss', 'iterations': 5}\n" ] } ], "source": [ "print('Model is fitted: ' + str(model.is_fitted()))\n", "print('Model params:')\n", "print(model.get_params())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Stdout of the training" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Learning rate set to 0.5\n", "0:\tlearn: 0.2985954\ttest: 0.2997087\tbest: 0.2997087 (0)\ttotal: 56.7ms\tremaining: 794ms\n", "1:\tlearn: 0.2148559\ttest: 0.2108328\tbest: 0.2108328 (1)\ttotal: 116ms\tremaining: 757ms\n", "2:\tlearn: 0.1841475\ttest: 0.1719133\tbest: 0.1719133 (2)\ttotal: 187ms\tremaining: 748ms\n", "3:\tlearn: 0.1760205\ttest: 
0.1608898\tbest: 0.1608898 (3)\ttotal: 260ms\tremaining: 715ms\n", "4:\tlearn: 0.1724111\ttest: 0.1553231\tbest: 0.1553231 (4)\ttotal: 321ms\tremaining: 642ms\n", "5:\tlearn: 0.1676418\ttest: 0.1498856\tbest: 0.1498856 (5)\ttotal: 402ms\tremaining: 603ms\n", "6:\tlearn: 0.1660562\ttest: 0.1478320\tbest: 0.1478320 (6)\ttotal: 476ms\tremaining: 544ms\n", "7:\tlearn: 0.1653090\ttest: 0.1467205\tbest: 0.1467205 (7)\ttotal: 558ms\tremaining: 488ms\n", "8:\tlearn: 0.1649062\ttest: 0.1465720\tbest: 0.1465720 (8)\ttotal: 606ms\tremaining: 404ms\n", "9:\tlearn: 0.1643222\ttest: 0.1463720\tbest: 0.1463720 (9)\ttotal: 685ms\tremaining: 342ms\n", "10:\tlearn: 0.1633503\ttest: 0.1450909\tbest: 0.1450909 (10)\ttotal: 752ms\tremaining: 273ms\n", "11:\tlearn: 0.1632233\ttest: 0.1451206\tbest: 0.1450909 (10)\ttotal: 815ms\tremaining: 204ms\n", "12:\tlearn: 0.1626563\ttest: 0.1453396\tbest: 0.1450909 (10)\ttotal: 879ms\tremaining: 135ms\n", "13:\tlearn: 0.1620230\ttest: 0.1445742\tbest: 0.1445742 (13)\ttotal: 941ms\tremaining: 67.2ms\n", "14:\tlearn: 0.1614001\ttest: 0.1444070\tbest: 0.1444070 (14)\ttotal: 1.01s\tremaining: 0us\n", "\n", "bestTest = 0.1444069686\n", "bestIteration = 14\n", "\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from catboost import CatBoostClassifier\n", "model = CatBoostClassifier(\n", " iterations=15,\n", " #verbose=5,\n", " logging_level='Verbose'\n", ")\n", "model.fit(\n", " X_train, y_train,\n", " cat_features=cat_features,\n", " eval_set=(X_validation, y_validation),\n", ")" ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Learning rate set to 0.5\n", "\n", "{ROLE_ROLLUP_1} pr2 tb0 type0, border=11 score 72.71470337\n", "{ROLE_ROLLUP_1} pr1 tb0 type0, border=4 score 73.13891169\n", "{ROLE_ROLLUP_1, ROLE_ROLLUP_2} pr2 tb0 type0, border=11 score 72.65449439\n", 
"{ROLE_ROLLUP_1} pr1 tb0 type0, border=2 score 72.85344413\n", " tensor 3 is redundant, remove it and stop\n", "0:\tlearn: 0.2985954\ttest: 0.2997087\tbest: 0.2997087 (0)\ttotal: 55.5ms\tremaining: 776ms\n", "\n", "{ROLE_TITLE} pr2 tb0 type0, border=13 score 22.73425202\n", "{ROLE_DEPTNAME, ROLE_TITLE} pr2 tb0 type0, border=10 score 23.43007817\n", "{ROLE_DEPTNAME, ROLE_TITLE} pr1 tb0 type0, border=3 score 25.16989885\n", "{ROLE_TITLE} pr0 tb0 type0, border=2 score 25.40915672\n", "{ROLE_TITLE} pr0 tb0 type0, border=14 score 25.17279721\n", " tensor 4 is redundant, remove it and stop\n", "1:\tlearn: 0.2148559\ttest: 0.2108328\tbest: 0.2108328 (1)\ttotal: 118ms\tremaining: 769ms\n", "\n", "{MGR_ID} pr2 tb0 type0, border=11 score 10.02687185\n", "{MGR_ID, ROLE_FAMILY_DESC} pr2 tb0 type0, border=11 score 13.662788\n", "{ROLE_ROLLUP_2} pr2 tb0 type0, border=13 score 15.3551266\n", "{ROLE_ROLLUP_1} pr2 tb0 type0, border=8 score 15.4183659\n", "{MGR_ID, ROLE_FAMILY_DESC} pr1 tb0 type0, border=12 score 15.60179644\n", "{ROLE_TITLE} pr0 tb0 type0, border=3 score 15.46732072\n", "2:\tlearn: 0.1841475\ttest: 0.1719133\tbest: 0.1719133 (2)\ttotal: 187ms\tremaining: 748ms\n", "\n", "{MGR_ID} pr1 tb0 type0, border=6 score 5.040708161\n", "{MGR_ID, ROLE_FAMILY_DESC} pr2 tb0 type0, border=11 score 7.163894\n", "{MGR_ID, ROLE_TITLE, ROLE_FAMILY_DESC} pr2 tb0 type0, border=12 score 7.509251314\n", "{ROLE_TITLE} pr1 tb0 type0, border=3 score 8.40093362\n", "{ROLE_DEPTNAME, ROLE_TITLE} pr2 tb0 type0, border=2 score 8.54481688\n", "{MGR_ID} pr2 tb0 type0, border=5 score 8.489125383\n", "3:\tlearn: 0.1760205\ttest: 0.1608898\tbest: 0.1608898 (3)\ttotal: 253ms\tremaining: 694ms\n", "\n", "{MGR_ID} pr2 tb0 type0, border=9 score 3.878720433\n", "{RESOURCE} pr2 tb0 type0, border=10 score 4.787399099\n", "{MGR_ID, ROLE_FAMILY_DESC} pr2 tb0 type0, border=14 score 5.664541783\n", "{ROLE_CODE} pr2 tb0 type0, border=2 score 5.968804419\n", " tensor 3 is redundant, remove it and stop\n", 
"4:\tlearn: 0.1724111\ttest: 0.1553231\tbest: 0.1553231 (4)\ttotal: 304ms\tremaining: 607ms\n", "\n", "{RESOURCE} pr2 tb0 type0, border=11 score 2.468191783\n", "{RESOURCE, ROLE_DEPTNAME} pr1 tb0 type0, border=8 score 3.031268761\n", "{ROLE_ROLLUP_1} pr0 tb0 type1, border=1 score 4.007602059\n", "{ROLE_ROLLUP_1, ROLE_DEPTNAME} pr2 tb0 type0, border=6 score 4.169017801\n", "{ROLE_ROLLUP_1, ROLE_DEPTNAME, ROLE_FAMILY} pr2 tb0 type0, border=11 score 4.602341778\n", "{ROLE_ROLLUP_1} pr1 tb0 type0, border=11 score 5.038311598\n", "5:\tlearn: 0.1676418\ttest: 0.1498856\tbest: 0.1498856 (5)\ttotal: 379ms\tremaining: 568ms\n", "\n", "{ROLE_DEPTNAME} pr2 tb0 type0, border=6 score 1.91890204\n", "{RESOURCE, ROLE_DEPTNAME} pr1 tb0 type0, border=10 score 2.908847019\n", "{MGR_ID, ROLE_DEPTNAME} pr2 tb0 type0, border=12 score 3.606018456\n", "{RESOURCE} pr0 tb0 type0, border=13 score 4.307821626\n", "{ROLE_TITLE} pr0 tb0 type0, border=14 score 4.579777128\n", " tensor 4 is redundant, remove it and stop\n", "6:\tlearn: 0.1660562\ttest: 0.1478320\tbest: 0.1478320 (6)\ttotal: 445ms\tremaining: 508ms\n", "\n", "{RESOURCE} pr1 tb0 type0, border=1 score 2.258001596\n", "{RESOURCE, ROLE_DEPTNAME} pr2 tb0 type0, border=5 score 2.860735033\n", "{ROLE_TITLE} pr2 tb0 type0, border=5 score 3.18948996\n", "{ROLE_DEPTNAME, ROLE_TITLE} pr2 tb0 type0, border=13 score 3.944580691\n", "{ROLE_DEPTNAME, ROLE_TITLE, ROLE_FAMILY_DESC} pr2 tb0 type0, border=7 score 4.163339306\n", "{ROLE_ROLLUP_2} pr0 tb0 type0, border=14 score 4.519983493\n", " tensor 5 is redundant, remove it and stop\n", "7:\tlearn: 0.1653090\ttest: 0.1467205\tbest: 0.1467205 (7)\ttotal: 517ms\tremaining: 452ms\n", "\n", "{RESOURCE} pr1 tb0 type0, border=1 score 2.213992054\n", "{RESOURCE} pr2 tb0 type0, border=12 score 2.814380964\n", "{ROLE_CODE} pr2 tb0 type0, border=1 score 3.054719636\n", " tensor 2 is redundant, remove it and stop\n", "8:\tlearn: 0.1649062\ttest: 0.1465720\tbest: 0.1465720 (8)\ttotal: 560ms\tremaining: 
373ms\n", "\n", "{RESOURCE} pr2 tb0 type0, border=6 score 2.129653303\n", "{ROLE_ROLLUP_2} pr0 tb0 type0, border=12 score 2.50659707\n", "{ROLE_ROLLUP_2, ROLE_DEPTNAME} pr1 tb0 type0, border=6 score 2.826564801\n", "{RESOURCE, ROLE_DEPTNAME} pr1 tb0 type0, border=7 score 3.864031498\n", "{ROLE_FAMILY} pr2 tb0 type0, border=0 score 3.785986162\n", " tensor 4 is redundant, remove it and stop\n", "9:\tlearn: 0.1643222\ttest: 0.1463720\tbest: 0.1463720 (9)\ttotal: 625ms\tremaining: 312ms\n", "\n", "{ROLE_DEPTNAME} pr1 tb0 type0, border=13 score 1.598079397\n", "{ROLE_DEPTNAME, ROLE_TITLE} pr1 tb0 type0, border=3 score 2.182451775\n", "{ROLE_DEPTNAME, ROLE_TITLE, ROLE_FAMILY_DESC} pr2 tb0 type0, border=12 score 3.45517723\n", "{MGR_ID} pr2 tb0 type0, border=1 score 4.087126108\n", "{ROLE_DEPTNAME, ROLE_TITLE} pr2 tb0 type0, border=9 score 4.330920437\n", "{ROLE_TITLE} pr2 tb0 type0, border=0 score 4.436657706\n", " tensor 5 is redundant, remove it and stop\n", "10:\tlearn: 0.1633503\ttest: 0.1450909\tbest: 0.1450909 (10)\ttotal: 687ms\tremaining: 250ms\n", "\n", "{ROLE_FAMILY_DESC} pr1 tb0 type0, border=4 score 1.470512618\n", "{ROLE_DEPTNAME} pr2 tb0 type0, border=13 score 2.656999431\n", "{ROLE_FAMILY} pr0 tb0 type1, border=1 score 3.102896991\n", "{ROLE_TITLE, ROLE_FAMILY} pr0 tb0 type0, border=14 score 3.5519837\n", " tensor 3 is redundant, remove it and stop\n", "11:\tlearn: 0.1632233\ttest: 0.1451206\tbest: 0.1450909 (10)\ttotal: 739ms\tremaining: 185ms\n", "\n", "{ROLE_ROLLUP_2} pr0 tb0 type0, border=9 score 1.63415229\n", "{ROLE_ROLLUP_2, ROLE_DEPTNAME} pr0 tb0 type0, border=6 score 1.841313561\n", "{ROLE_ROLLUP_2, ROLE_TITLE} pr2 tb0 type0, border=13 score 2.839776224\n", "{ROLE_ROLLUP_2} pr0 tb0 type1, border=6 score 3.034144356\n", "{ROLE_ROLLUP_2, ROLE_DEPTNAME, ROLE_FAMILY_DESC} pr2 tb0 type0, border=9 score 3.381467012\n", "{ROLE_TITLE} pr1 tb0 type0, border=5 score 3.77508779\n", "12:\tlearn: 0.1626563\ttest: 0.1453396\tbest: 0.1450909 (10)\ttotal: 
792ms\tremaining: 122ms\n", "\n", "{ROLE_FAMILY_DESC} pr2 tb0 type0, border=14 score 1.643727886\n", "{ROLE_TITLE, ROLE_FAMILY_DESC} pr1 tb0 type0, border=13 score 2.322308423\n", "{MGR_ID, ROLE_TITLE, ROLE_FAMILY_DESC} pr2 tb0 type0, border=6 score 2.494762742\n", "{MGR_ID, ROLE_ROLLUP_2, ROLE_TITLE, ROLE_FAMILY_DESC} pr2 tb0 type0, border=6 score 2.723569126\n", "{MGR_ID, ROLE_TITLE, ROLE_FAMILY_DESC} pr2 tb0 type0, border=1 score 3.439027181\n", "{ROLE_TITLE} pr1 tb0 type0, border=14 score 3.425081938\n", " tensor 5 is redundant, remove it and stop\n", "13:\tlearn: 0.1620230\ttest: 0.1445742\tbest: 0.1445742 (13)\ttotal: 863ms\tremaining: 61.6ms\n", "\n", "{ROLE_ROLLUP_2} pr2 tb0 type0, border=11 score 1.510460943\n", "{ROLE_ROLLUP_2, ROLE_FAMILY_DESC} pr1 tb0 type0, border=2 score 2.020092181\n", "{MGR_ID, ROLE_ROLLUP_2, ROLE_FAMILY_DESC} pr2 tb0 type0, border=6 score 2.705877322\n", "{ROLE_ROLLUP_2, ROLE_DEPTNAME} pr2 tb0 type0, border=9 score 3.767178056\n", "{ROLE_TITLE} pr0 tb0 type0, border=11 score 4.23047388\n", "{ROLE_TITLE, ROLE_FAMILY} pr0 tb0 type0, border=5 score 4.350755703\n", "14:\tlearn: 0.1614001\ttest: 0.1444070\tbest: 0.1444070 (14)\ttotal: 938ms\tremaining: 0us\n", "\n", "bestTest = 0.1444069686\n", "bestIteration = 14\n", "\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from catboost import CatBoostClassifier\n", "model = CatBoostClassifier(\n", " iterations=15,\n", " #verbose=5,\n", " logging_level='Info'\n", ")\n", "model.fit(\n", " X_train, y_train,\n", " cat_features=cat_features,\n", " eval_set=(X_validation, y_validation),\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Metrics calculation and graph plotting" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { 
"application/vnd.jupyter.widget-view+json": { "model_id": "50383c010c304ab4afe465a9ffee0c10", "version_major": 2, "version_minor": 0 }, "text/plain": [ "MetricVisualizer(layout=Layout(align_self=u'stretch', height=u'500px'))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from catboost import CatBoostClassifier\n", "model = CatBoostClassifier(\n", " iterations=500,\n", " random_seed=63,\n", " learning_rate=0.5\n", ")\n", "model.fit(\n", " X_train, y_train,\n", " cat_features=cat_features,\n", " eval_set=(X_validation, y_validation),\n", " verbose=False,\n", " plot=True\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Eval metric, custom metrics and best tree count" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6fe0b05ae929456c89fbd3046b7de131", "version_major": 2, "version_minor": 0 }, "text/plain": [ "MetricVisualizer(layout=Layout(align_self=u'stretch', height=u'500px'))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from catboost import CatBoostClassifier\n", "model = CatBoostClassifier(\n", " iterations=50,\n", " random_seed=63,\n", " learning_rate=0.5,\n", " eval_metric=\"Accuracy\",\n", " use_best_model=False  # keep all 50 trees instead of shrinking the model to the best iteration\n", ")\n", "\n", "model.fit(\n", " X_train, y_train,\n", " cat_features=cat_features,\n", " eval_set=(X_validation, y_validation),\n", " verbose=False,\n", " plot=True\n", ")" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "50" ] }, "execution_count": 37, "metadata": {}, 
"output_type": "execute_result" } ], "source": [ "model.tree_count_" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8829f8544aa144c7a2e94df99bc88e1b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "MetricVisualizer(layout=Layout(align_self=u'stretch', height=u'500px'))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from catboost import CatBoostClassifier\n", "model = CatBoostClassifier(\n", " iterations=50,\n", " random_seed=63,\n", " learning_rate=0.5,\n", " custom_loss=['AUC', 'Accuracy']  # extra metrics to compute and plot; they are not optimized\n", ")\n", "model.fit(\n", " X_train, y_train,\n", " cat_features=cat_features,\n", " eval_set=(X_validation, y_validation),\n", " verbose=False,\n", " plot=True\n", ")" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "21" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# use_best_model defaults to True when eval_set is given, so only the trees up to the best iteration are kept\n", "model.tree_count_" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Metric hints" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ebf229a449b040e4aa7e2d72147792f4", "version_major": 2, "version_minor": 0 }, "text/plain": [ "MetricVisualizer(layout=Layout(align_self=u'stretch', height=u'500px'))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from catboost import 
CatBoostClassifier\n", "\n", "model = CatBoostClassifier(\n", " iterations=50,\n", " random_seed=63,\n", " learning_rate=0.5,\n", " eval_metric='AUC:hints=skip_train~false' #default\n", ")\n", "\n", "model.fit(\n", " X_train, y_train,\n", " cat_features=cat_features,\n", " eval_set=(X_validation, y_validation),\n", " verbose=False,\n", " plot=True\n", ")" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f10f866248e24cc991b55711dcb57d33", "version_major": 2, "version_minor": 0 }, "text/plain": [ "MetricVisualizer(layout=Layout(align_self=u'stretch', height=u'500px'))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from catboost import CatBoostClassifier\n", "model = CatBoostClassifier(\n", " iterations=50,\n", " random_seed=63,\n", " learning_rate=0.5,\n", " eval_metric='AUC:hints=skip_train~false', #default\n", " metric_period=10\n", ")\n", "model.fit(\n", " X_train, y_train,\n", " cat_features=cat_features,\n", " eval_set=(X_validation, y_validation),\n", " verbose=False,\n", " plot=True\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model comparison" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model1 = CatBoostClassifier(\n", " learning_rate=0.5,\n", " iterations=100,\n", " train_dir='learing_rate_0.5'\n", ")\n", "\n", "model2 = CatBoostClassifier(\n", " learning_rate=0.01,\n", " iterations=100,\n", " train_dir='learing_rate_0.01'\n", ")\n", "\n", "model1.fit(\n", " X_train, y_train,\n", " eval_set=(X_validation, 
y_validation),\n", " cat_features=cat_features,\n", " verbose=False\n", ")\n", "model2.fit(\n", " X_train, y_train,\n", " eval_set=(X_validation, y_validation),\n", " cat_features=cat_features,\n", " verbose=False\n", ")" ] }, { "cell_type": "code", "execution_count": 43, "metadata": { "scrolled": false }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ea689d29144c46b48c7b6cb67d4fd6d7", "version_major": 2, "version_minor": 0 }, "text/plain": [ "MetricVisualizer(layout=Layout(align_self=u'stretch', height=u'500px'))" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from catboost import MetricVisualizer\n", "MetricVisualizer(['learing_rate_0.01', 'learing_rate_0.5']).start()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Overfitting detector" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c942cf28d09e49bd90c520160340dc3b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "MetricVisualizer(layout=Layout(align_self=u'stretch', height=u'500px'))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_with_early_stop = CatBoostClassifier(\n", " iterations=200,\n", " random_seed=63,\n", " learning_rate=0.5,\n", " early_stopping_rounds=20\n", ")\n", "model_with_early_stop.fit(\n", " X_train, y_train,\n", " cat_features=cat_features,\n", " eval_set=(X_validation, y_validation),\n", " verbose=False,\n", " plot=True\n", ")" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "name": 
"stdout", "output_type": "stream", "text": [ "21\n" ] } ], "source": [ "print(model_with_early_stop.tree_count_)" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6d94edb621784b11bd6d0886277d5c7d", "version_major": 2, "version_minor": 0 }, "text/plain": [ "MetricVisualizer(layout=Layout(align_self=u'stretch', height=u'500px'))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_with_early_stop = CatBoostClassifier(\n", " eval_metric='AUC',\n", " iterations=200,\n", " random_seed=63,\n", " learning_rate=0.5,\n", " early_stopping_rounds=20\n", ")\n", "model_with_early_stop.fit(\n", " X_train, y_train,\n", " cat_features=cat_features,\n", " eval_set=(X_validation, y_validation),\n", " verbose=False,\n", " plot=True\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Cross-validation" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e822f1fa217f40338ca0d0cbf140ec71", "version_major": 2, "version_minor": 0 }, "text/plain": [ "MetricVisualizer(layout=Layout(align_self=u'stretch', height=u'500px'))" ] }, "metadata": {}, "output_type": "display_data" }, { "ename": "CatBoostError", "evalue": "library/json/writer/json_value.cpp:515: Not an array", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mCatBoostError\u001b[0m Traceback (most recent call last)", 
"\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0mplot\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0mstratified\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 19\u001b[0;31m \u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 20\u001b[0m )\n", "\u001b[0;32m/home/annaveronika/trunk/arcadia/catboost/python-package/catboost/core.pyc\u001b[0m in \u001b[0;36mcv\u001b[0;34m(pool, params, dtrain, iterations, num_boost_round, fold_count, nfold, inverted, partition_random_seed, seed, shuffle, logging_level, stratified, as_pandas, metric_period, verbose, verbose_eval, plot, early_stopping_rounds, save_snapshot, snapshot_file, snapshot_interval, iterations_batch_size)\u001b[0m\n\u001b[1;32m 2928\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mlog_fixup\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mplot_wrapper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mplot\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2929\u001b[0m return _cv(params, pool, fold_count, inverted, partition_random_seed, shuffle, stratified,\n\u001b[0;32m-> 2930\u001b[0;31m as_pandas, iterations_batch_size)\n\u001b[0m\u001b[1;32m 2931\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2932\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m_catboost.pyx\u001b[0m in \u001b[0;36m_catboost._cv\u001b[0;34m()\u001b[0m\n", "\u001b[0;32m_catboost.pyx\u001b[0m in \u001b[0;36m_catboost._cv\u001b[0;34m()\u001b[0m\n", "\u001b[0;31mCatBoostError\u001b[0m: library/json/writer/json_value.cpp:515: Not an array" ] } ], "source": [ "from catboost import cv\n", "\n", "params = {}\n", "params['loss_function'] = 
'Logloss'\n", "params['iterations'] = 80\n", "params['custom_loss'] = ['AUC']\n", "params['random_seed'] = 63\n", "params['learning_rate'] = 0.5\n", "\n", "cv_data = cv(\n", " params = params,\n", " pool = Pool(X, label=y, cat_features=cat_features),\n", " fold_count=5,\n", " inverted=False,\n", " shuffle=True,\n", " partition_random_seed=0,\n", " plot=True,\n", " stratified=False,\n", " verbose=False\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cv_data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "best_value = np.min(cv_data['test-Logloss-mean'])\n", "best_iter = cv_data['test-Logloss-mean'].idxmin()\n", "\n", "print('Best validation Logloss score, not stratified: {:.4f}±{:.4f} on step {}'.format(\n", " best_value,\n", " cv_data['test-Logloss-std'][best_iter],\n", " best_iter)\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cv_data = cv(\n", " params = params,\n", " pool = Pool(X, label=y, cat_features=cat_features),\n", " fold_count=5,\n", " inverted=False,\n", " shuffle=True,\n", " partition_random_seed=0,\n", " plot=True,\n", " stratified=True,\n", " verbose=False\n", ")\n", "\n", "best_value = np.min(cv_data['test-Logloss-mean'])\n", "best_iter = cv_data['test-Logloss-mean'].idxmin()\n", "\n", "print('Best validation Logloss score, stratified: {:.4f}±{:.4f} on step {}'.format(\n", " best_value,\n", " cv_data['test-Logloss-std'][best_iter],\n", " best_iter)\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Select decision boundary" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2d3e64ab6c4349baa355f12ad740bc89", "version_major": 2, 
"version_minor": 0 }, "text/plain": [ "MetricVisualizer(layout=Layout(align_self=u'stretch', height=u'500px'))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = CatBoostClassifier(\n", " random_seed=63,\n", " iterations=200,\n", " learning_rate=0.03,\n", ")\n", "model.fit(\n", " X_train, y_train,\n", " cat_features=cat_features,\n", " verbose=False,\n", " plot=True\n", ")" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import sklearn\n", "from sklearn import metrics\n", "from catboost.utils import get_roc_curve" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [], "source": [ "? get_roc_curve" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "eval_pool = Pool(X_validation, y_validation, cat_features=cat_features)\n", "curve = get_roc_curve(model, eval_pool)\n", "(fpr, tpr, thresholds) = curve\n", "\n", "plt.figure()\n", "lw = 2\n", "roc_auc = sklearn.metrics.auc(fpr, tpr)\n", "plt.plot(fpr, tpr, color='darkorange',\n", " lw=lw, label='ROC curve (area = %0.2f)' % roc_auc)\n", "\n", "plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')\n", "plt.xlim([0.0, 1.0])\n", "plt.ylim([0.0, 1.05])\n", "plt.xlabel('False Positive Rate')\n", "plt.ylabel('True Positive Rate')\n", "plt.title('Receiver operating characteristic')\n", "plt.legend(loc=\"lower right\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [], "source": [ "from catboost.utils import get_fpr_curve\n", "from catboost.utils import get_fnr_curve" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [], "source": [ "? get_fpr_curve" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.figure()\n", "lw = 2\n", "(thresholds, fpr) = get_fpr_curve(curve=curve)\n", "(thresholds, fnr) = get_fnr_curve(curve=curve)\n", "plt.plot(thresholds, fpr, color='blue', lw=lw, label='FPR')\n", "plt.plot(thresholds, fnr, color='green', lw=lw, label='FNR')\n", "\n", "plt.xlim([0.0, 1.0])\n", "plt.ylim([0.0, 1.05])\n", "plt.xlabel('Threshold')\n", "plt.ylabel('Error Rate')\n", "plt.title('FPR-FNR curves')\n", "plt.legend(loc=\"lower left\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.471901498177\n", "0.987647294429\n" ] } ], "source": [ "from catboost.utils import select_threshold\n", "\n", "print(select_threshold(model=model, data=eval_pool, FNR=0.01))\n", "print(select_threshold(model=model, data=eval_pool, FPR=0.01))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Snapshotting" ] }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [], "source": [ "#!rm 'catboost_info/snapshot.bkp'\n", "\n" ] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Learning rate set to 0.5\n", "0:\tlearn: 0.2965875\ttest: 0.2955232\tbest: 0.2955232 (0)\ttotal: 59.9ms\tremaining: 1.14s\n", "1:\tlearn: 0.2160431\ttest: 0.2105519\tbest: 0.2105519 (1)\ttotal: 120ms\tremaining: 1.08s\n", "2:\tlearn: 0.1847090\ttest: 0.1727051\tbest: 0.1727051 (2)\ttotal: 195ms\tremaining: 1.11s\n", "3:\tlearn: 0.1761768\ttest: 0.1597829\tbest: 0.1597829 (3)\ttotal: 258ms\tremaining: 1.03s\n", "4:\tlearn: 0.1714326\ttest: 0.1548739\tbest: 0.1548739 (4)\ttotal: 325ms\tremaining: 976ms\n", "5:\tlearn: 0.1694150\ttest: 0.1517186\tbest: 0.1517186 (5)\ttotal: 388ms\tremaining: 905ms\n", "6:\tlearn: 0.1675766\ttest: 0.1498146\tbest: 0.1498146 (6)\ttotal: 459ms\tremaining: 852ms\n", "7:\tlearn: 0.1667392\ttest: 
0.1488889\tbest: 0.1488889 (7)\ttotal: 518ms\tremaining: 777ms\n", "8:\tlearn: 0.1664645\ttest: 0.1489624\tbest: 0.1488889 (7)\ttotal: 570ms\tremaining: 697ms\n", "9:\tlearn: 0.1659541\ttest: 0.1484477\tbest: 0.1484477 (9)\ttotal: 636ms\tremaining: 636ms\n", "10:\tlearn: 0.1647458\ttest: 0.1472531\tbest: 0.1472531 (10)\ttotal: 698ms\tremaining: 571ms\n", "11:\tlearn: 0.1647261\ttest: 0.1471785\tbest: 0.1471785 (11)\ttotal: 748ms\tremaining: 498ms\n", "12:\tlearn: 0.1644109\ttest: 0.1466226\tbest: 0.1466226 (12)\ttotal: 806ms\tremaining: 434ms\n", "13:\tlearn: 0.1644007\ttest: 0.1466116\tbest: 0.1466116 (13)\ttotal: 853ms\tremaining: 366ms\n", "14:\tlearn: 0.1641944\ttest: 0.1464603\tbest: 0.1464603 (14)\ttotal: 907ms\tremaining: 302ms\n", "15:\tlearn: 0.1626904\ttest: 0.1452219\tbest: 0.1452219 (15)\ttotal: 967ms\tremaining: 242ms\n", "16:\tlearn: 0.1619736\ttest: 0.1454357\tbest: 0.1452219 (15)\ttotal: 1.03s\tremaining: 183ms\n", "17:\tlearn: 0.1619147\ttest: 0.1453835\tbest: 0.1452219 (15)\ttotal: 1.1s\tremaining: 123ms\n", "18:\tlearn: 0.1619084\ttest: 0.1452848\tbest: 0.1452219 (15)\ttotal: 1.15s\tremaining: 60.8ms\n", "19:\tlearn: 0.1613569\ttest: 0.1452973\tbest: 0.1452219 (15)\ttotal: 1.22s\tremaining: 0us\n", "\n", "bestTest = 0.1452219235\n", "bestIteration = 15\n", "\n", "Shrink model to first 16 iterations.\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 57, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#!rm 'catboost_info/snapshot.bkp'\n", "from catboost import CatBoostClassifier\n", "model = CatBoostClassifier(\n", " iterations=20,\n", " save_snapshot=True,\n", " snapshot_file='snapshot.bkp',\n", " snapshot_interval=1,\n", " random_seed=43\n", ")\n", "model.fit(\n", " X_train, y_train,\n", " eval_set=(X_validation, y_validation),\n", " cat_features=cat_features,\n", " verbose=True\n", ")" ] }, { "cell_type": "code", "execution_count": 58, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [ "\n", 
" \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Learning rate set to 0.218957\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c7042c21faab470497825be24e3581d8", "version_major": 2, "version_minor": 0 }, "text/plain": [ "MetricVisualizer(layout=Layout(align_self=u'stretch', height=u'500px'))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "20:\tlearn: 0.1612152\ttest: 0.1452297\tbest: 0.1452219 (15)\ttotal: 1.28s\tremaining: 11.5s\n", "21:\tlearn: 0.1612042\ttest: 0.1452362\tbest: 0.1452219 (15)\ttotal: 1.33s\tremaining: 10.6s\n", "22:\tlearn: 0.1610427\ttest: 0.1452320\tbest: 0.1452219 (15)\ttotal: 1.4s\tremaining: 11.1s\n", "23:\tlearn: 0.1602574\ttest: 0.1443945\tbest: 0.1443945 (23)\ttotal: 1.49s\tremaining: 11.9s\n", "24:\tlearn: 0.1600288\ttest: 0.1443164\tbest: 0.1443164 (24)\ttotal: 1.56s\tremaining: 12.1s\n", "25:\tlearn: 0.1589975\ttest: 0.1435700\tbest: 0.1435700 (25)\ttotal: 1.63s\tremaining: 12.1s\n", "26:\tlearn: 0.1589964\ttest: 0.1435730\tbest: 0.1435700 (25)\ttotal: 1.69s\tremaining: 11.6s\n", "27:\tlearn: 0.1588313\ttest: 0.1433549\tbest: 0.1433549 (27)\ttotal: 1.75s\tremaining: 11.5s\n", "28:\tlearn: 0.1588014\ttest: 0.1433848\tbest: 0.1433549 (27)\ttotal: 1.81s\tremaining: 11.3s\n", "29:\tlearn: 0.1584245\ttest: 0.1431081\tbest: 0.1431081 (29)\ttotal: 1.88s\tremaining: 11.2s\n", "30:\tlearn: 0.1583421\ttest: 0.1431004\tbest: 0.1431004 (30)\ttotal: 1.94s\tremaining: 11.1s\n", "31:\tlearn: 0.1583295\ttest: 0.1430967\tbest: 0.1430967 (31)\ttotal: 2s\tremaining: 10.9s\n", "32:\tlearn: 0.1583294\ttest: 0.1430988\tbest: 0.1430967 (31)\ttotal: 2.04s\tremaining: 10.6s\n", "33:\tlearn: 0.1582839\ttest: 0.1431464\tbest: 0.1430967 (31)\ttotal: 2.1s\tremaining: 10.5s\n", "34:\tlearn: 0.1582705\ttest: 0.1430595\tbest: 0.1430595 (34)\ttotal: 2.14s\tremaining: 10.2s\n", 
"35:\tlearn: 0.1582696\ttest: 0.1430571\tbest: 0.1430571 (35)\ttotal: 2.19s\tremaining: 9.97s\n", "36:\tlearn: 0.1577099\ttest: 0.1426513\tbest: 0.1426513 (36)\ttotal: 2.26s\tremaining: 10s\n", "37:\tlearn: 0.1577047\ttest: 0.1426538\tbest: 0.1426513 (36)\ttotal: 2.34s\tremaining: 10.1s\n", "38:\tlearn: 0.1576397\ttest: 0.1426637\tbest: 0.1426513 (36)\ttotal: 2.39s\tremaining: 9.97s\n", "39:\tlearn: 0.1576390\ttest: 0.1426667\tbest: 0.1426513 (36)\ttotal: 2.44s\tremaining: 9.81s\n", "40:\tlearn: 0.1576366\ttest: 0.1426616\tbest: 0.1426513 (36)\ttotal: 2.49s\tremaining: 9.63s\n", "41:\tlearn: 0.1575441\ttest: 0.1427510\tbest: 0.1426513 (36)\ttotal: 2.55s\tremaining: 9.59s\n", "42:\tlearn: 0.1573020\ttest: 0.1424677\tbest: 0.1424677 (42)\ttotal: 2.62s\tremaining: 9.61s\n", "43:\tlearn: 0.1572261\ttest: 0.1425183\tbest: 0.1424677 (42)\ttotal: 2.69s\tremaining: 9.59s\n", "44:\tlearn: 0.1571721\ttest: 0.1424962\tbest: 0.1424677 (42)\ttotal: 2.75s\tremaining: 9.48s\n", "45:\tlearn: 0.1567994\ttest: 0.1423557\tbest: 0.1423557 (45)\ttotal: 2.81s\tremaining: 9.47s\n", "46:\tlearn: 0.1565110\ttest: 0.1418114\tbest: 0.1418114 (46)\ttotal: 2.88s\tremaining: 9.41s\n", "47:\tlearn: 0.1557751\ttest: 0.1414036\tbest: 0.1414036 (47)\ttotal: 2.94s\tremaining: 9.35s\n", "48:\tlearn: 0.1557492\ttest: 0.1413917\tbest: 0.1413917 (48)\ttotal: 3s\tremaining: 9.27s\n", "49:\tlearn: 0.1556371\ttest: 0.1415897\tbest: 0.1413917 (48)\ttotal: 3.06s\tremaining: 9.21s\n", "50:\tlearn: 0.1555799\ttest: 0.1416534\tbest: 0.1413917 (48)\ttotal: 3.12s\tremaining: 9.17s\n", "51:\tlearn: 0.1553364\ttest: 0.1415682\tbest: 0.1413917 (48)\ttotal: 3.19s\tremaining: 9.14s\n", "52:\tlearn: 0.1549614\ttest: 0.1414633\tbest: 0.1413917 (48)\ttotal: 3.26s\tremaining: 9.12s\n", "53:\tlearn: 0.1549144\ttest: 0.1414808\tbest: 0.1413917 (48)\ttotal: 3.32s\tremaining: 9.02s\n", "54:\tlearn: 0.1542091\ttest: 0.1411187\tbest: 0.1411187 (54)\ttotal: 3.4s\tremaining: 9.06s\n", "55:\tlearn: 0.1540154\ttest: 
0.1411603\tbest: 0.1411187 (54)\ttotal: 3.47s\tremaining: 9.03s\n", "56:\tlearn: 0.1539392\ttest: 0.1412245\tbest: 0.1411187 (54)\ttotal: 3.54s\tremaining: 8.99s\n", "57:\tlearn: 0.1538993\ttest: 0.1411606\tbest: 0.1411187 (54)\ttotal: 3.61s\tremaining: 8.93s\n", "58:\tlearn: 0.1538640\ttest: 0.1411852\tbest: 0.1411187 (54)\ttotal: 3.67s\tremaining: 8.86s\n", "59:\tlearn: 0.1537952\ttest: 0.1411267\tbest: 0.1411187 (54)\ttotal: 3.73s\tremaining: 8.8s\n", "60:\tlearn: 0.1537380\ttest: 0.1411907\tbest: 0.1411187 (54)\ttotal: 3.8s\tremaining: 8.76s\n", "61:\tlearn: 0.1536350\ttest: 0.1409801\tbest: 0.1409801 (61)\ttotal: 3.88s\tremaining: 8.74s\n", "62:\tlearn: 0.1535942\ttest: 0.1410172\tbest: 0.1409801 (61)\ttotal: 3.93s\tremaining: 8.66s\n", "63:\tlearn: 0.1535253\ttest: 0.1410033\tbest: 0.1409801 (61)\ttotal: 4s\tremaining: 8.62s\n", "64:\tlearn: 0.1534290\ttest: 0.1410210\tbest: 0.1409801 (61)\ttotal: 4.07s\tremaining: 8.55s\n", "65:\tlearn: 0.1534215\ttest: 0.1410008\tbest: 0.1409801 (61)\ttotal: 4.14s\tremaining: 8.53s\n", "66:\tlearn: 0.1532798\ttest: 0.1408475\tbest: 0.1408475 (66)\ttotal: 4.22s\tremaining: 8.49s\n", "67:\tlearn: 0.1530738\ttest: 0.1409062\tbest: 0.1408475 (66)\ttotal: 4.28s\tremaining: 8.44s\n", "68:\tlearn: 0.1529922\ttest: 0.1409187\tbest: 0.1408475 (66)\ttotal: 4.35s\tremaining: 8.38s\n", "69:\tlearn: 0.1529770\ttest: 0.1409706\tbest: 0.1408475 (66)\ttotal: 4.44s\tremaining: 8.38s\n", "70:\tlearn: 0.1529625\ttest: 0.1409251\tbest: 0.1408475 (66)\ttotal: 4.5s\tremaining: 8.3s\n", "71:\tlearn: 0.1529499\ttest: 0.1409476\tbest: 0.1408475 (66)\ttotal: 4.57s\tremaining: 8.24s\n", "72:\tlearn: 0.1528810\ttest: 0.1409514\tbest: 0.1408475 (66)\ttotal: 4.63s\tremaining: 8.18s\n", "73:\tlearn: 0.1528579\ttest: 0.1409500\tbest: 0.1408475 (66)\ttotal: 4.69s\tremaining: 8.1s\n", "74:\tlearn: 0.1528168\ttest: 0.1409697\tbest: 0.1408475 (66)\ttotal: 4.75s\tremaining: 8.04s\n", "75:\tlearn: 0.1526691\ttest: 0.1409859\tbest: 0.1408475 (66)\ttotal: 
4.82s\tremaining: 7.98s\n", "76:\tlearn: 0.1525093\ttest: 0.1409349\tbest: 0.1408475 (66)\ttotal: 4.88s\tremaining: 7.91s\n", "77:\tlearn: 0.1524918\ttest: 0.1409958\tbest: 0.1408475 (66)\ttotal: 4.94s\tremaining: 7.84s\n", "78:\tlearn: 0.1523623\ttest: 0.1411616\tbest: 0.1408475 (66)\ttotal: 5.01s\tremaining: 7.78s\n", "79:\tlearn: 0.1520972\ttest: 0.1411054\tbest: 0.1408475 (66)\ttotal: 5.07s\tremaining: 7.71s\n", "80:\tlearn: 0.1520786\ttest: 0.1411760\tbest: 0.1408475 (66)\ttotal: 5.14s\tremaining: 7.65s\n", "81:\tlearn: 0.1520778\ttest: 0.1411706\tbest: 0.1408475 (66)\ttotal: 5.19s\tremaining: 7.56s\n", "82:\tlearn: 0.1520648\ttest: 0.1412223\tbest: 0.1408475 (66)\ttotal: 5.26s\tremaining: 7.52s\n", "83:\tlearn: 0.1515509\ttest: 0.1408977\tbest: 0.1408475 (66)\ttotal: 5.33s\tremaining: 7.46s\n", "84:\tlearn: 0.1515241\ttest: 0.1409170\tbest: 0.1408475 (66)\ttotal: 5.41s\tremaining: 7.41s\n", "85:\tlearn: 0.1507357\ttest: 0.1406431\tbest: 0.1406431 (85)\ttotal: 5.5s\tremaining: 7.4s\n", "86:\tlearn: 0.1505739\ttest: 0.1406102\tbest: 0.1406102 (86)\ttotal: 5.57s\tremaining: 7.34s\n", "87:\tlearn: 0.1505528\ttest: 0.1406036\tbest: 0.1406036 (87)\ttotal: 5.65s\tremaining: 7.3s\n", "88:\tlearn: 0.1502959\ttest: 0.1407175\tbest: 0.1406036 (87)\ttotal: 5.72s\tremaining: 7.25s\n", "89:\tlearn: 0.1501328\ttest: 0.1404297\tbest: 0.1404297 (89)\ttotal: 5.79s\tremaining: 7.19s\n", "90:\tlearn: 0.1499854\ttest: 0.1403098\tbest: 0.1403098 (90)\ttotal: 5.86s\tremaining: 7.13s\n", "91:\tlearn: 0.1499249\ttest: 0.1401775\tbest: 0.1401775 (91)\ttotal: 5.93s\tremaining: 7.07s\n", "92:\tlearn: 0.1497086\ttest: 0.1401750\tbest: 0.1401750 (92)\ttotal: 6s\tremaining: 7s\n", "93:\tlearn: 0.1496704\ttest: 0.1401246\tbest: 0.1401246 (93)\ttotal: 6.06s\tremaining: 6.95s\n", "94:\tlearn: 0.1496610\ttest: 0.1401833\tbest: 0.1401246 (93)\ttotal: 6.13s\tremaining: 6.89s\n", "95:\tlearn: 0.1495849\ttest: 0.1401673\tbest: 0.1401246 (93)\ttotal: 6.21s\tremaining: 6.83s\n", "96:\tlearn: 
0.1494667\ttest: 0.1400086\tbest: 0.1400086 (96)\ttotal: 6.27s\tremaining: 6.77s\n", "97:\tlearn: 0.1493525\ttest: 0.1401855\tbest: 0.1400086 (96)\ttotal: 6.34s\tremaining: 6.71s\n", "98:\tlearn: 0.1489917\ttest: 0.1401021\tbest: 0.1400086 (96)\ttotal: 6.41s\tremaining: 6.64s\n", "99:\tlearn: 0.1489125\ttest: 0.1401268\tbest: 0.1400086 (96)\ttotal: 6.47s\tremaining: 6.57s\n", "100:\tlearn: 0.1489016\ttest: 0.1401160\tbest: 0.1400086 (96)\ttotal: 6.57s\tremaining: 6.54s\n", "101:\tlearn: 0.1487716\ttest: 0.1400791\tbest: 0.1400086 (96)\ttotal: 6.64s\tremaining: 6.48s\n", "102:\tlearn: 0.1487237\ttest: 0.1400954\tbest: 0.1400086 (96)\ttotal: 6.71s\tremaining: 6.42s\n", "103:\tlearn: 0.1486277\ttest: 0.1401909\tbest: 0.1400086 (96)\ttotal: 6.78s\tremaining: 6.36s\n", "104:\tlearn: 0.1486153\ttest: 0.1401898\tbest: 0.1400086 (96)\ttotal: 6.85s\tremaining: 6.29s\n", "105:\tlearn: 0.1485328\ttest: 0.1401209\tbest: 0.1400086 (96)\ttotal: 6.92s\tremaining: 6.23s\n", "106:\tlearn: 0.1484761\ttest: 0.1401947\tbest: 0.1400086 (96)\ttotal: 6.98s\tremaining: 6.16s\n", "107:\tlearn: 0.1484161\ttest: 0.1402889\tbest: 0.1400086 (96)\ttotal: 7.05s\tremaining: 6.1s\n", "108:\tlearn: 0.1483455\ttest: 0.1404464\tbest: 0.1400086 (96)\ttotal: 7.12s\tremaining: 6.03s\n", "109:\tlearn: 0.1482798\ttest: 0.1403834\tbest: 0.1400086 (96)\ttotal: 7.18s\tremaining: 5.96s\n", "110:\tlearn: 0.1482755\ttest: 0.1403591\tbest: 0.1400086 (96)\ttotal: 7.24s\tremaining: 5.89s\n", "111:\tlearn: 0.1481230\ttest: 0.1405596\tbest: 0.1400086 (96)\ttotal: 7.32s\tremaining: 5.84s\n", "112:\tlearn: 0.1477729\ttest: 0.1405197\tbest: 0.1400086 (96)\ttotal: 7.4s\tremaining: 5.79s\n", "113:\tlearn: 0.1477185\ttest: 0.1405934\tbest: 0.1400086 (96)\ttotal: 7.47s\tremaining: 5.72s\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "114:\tlearn: 0.1474968\ttest: 0.1405011\tbest: 0.1400086 (96)\ttotal: 7.54s\tremaining: 5.66s\n", "115:\tlearn: 0.1474954\ttest: 0.1405148\tbest: 0.1400086 (96)\ttotal: 
7.63s\tremaining: 5.61s\n", "116:\tlearn: 0.1474155\ttest: 0.1405253\tbest: 0.1400086 (96)\ttotal: 7.69s\tremaining: 5.54s\n", "117:\tlearn: 0.1473274\ttest: 0.1405328\tbest: 0.1400086 (96)\ttotal: 7.77s\tremaining: 5.48s\n", "118:\tlearn: 0.1472799\ttest: 0.1405519\tbest: 0.1400086 (96)\ttotal: 7.83s\tremaining: 5.42s\n", "119:\tlearn: 0.1471431\ttest: 0.1402404\tbest: 0.1400086 (96)\ttotal: 7.91s\tremaining: 5.35s\n", "120:\tlearn: 0.1470829\ttest: 0.1403126\tbest: 0.1400086 (96)\ttotal: 7.98s\tremaining: 5.29s\n", "121:\tlearn: 0.1470043\ttest: 0.1402714\tbest: 0.1400086 (96)\ttotal: 8.04s\tremaining: 5.22s\n", "122:\tlearn: 0.1469235\ttest: 0.1403153\tbest: 0.1400086 (96)\ttotal: 8.12s\tremaining: 5.16s\n", "123:\tlearn: 0.1469211\ttest: 0.1403098\tbest: 0.1400086 (96)\ttotal: 8.18s\tremaining: 5.09s\n", "124:\tlearn: 0.1467363\ttest: 0.1401837\tbest: 0.1400086 (96)\ttotal: 8.25s\tremaining: 5.02s\n", "125:\tlearn: 0.1467291\ttest: 0.1401725\tbest: 0.1400086 (96)\ttotal: 8.31s\tremaining: 4.96s\n", "126:\tlearn: 0.1466647\ttest: 0.1402149\tbest: 0.1400086 (96)\ttotal: 8.38s\tremaining: 4.89s\n", "127:\tlearn: 0.1465627\ttest: 0.1400712\tbest: 0.1400086 (96)\ttotal: 8.45s\tremaining: 4.82s\n", "128:\tlearn: 0.1465278\ttest: 0.1401424\tbest: 0.1400086 (96)\ttotal: 8.52s\tremaining: 4.75s\n", "129:\tlearn: 0.1465071\ttest: 0.1400993\tbest: 0.1400086 (96)\ttotal: 8.58s\tremaining: 4.69s\n", "130:\tlearn: 0.1464897\ttest: 0.1400964\tbest: 0.1400086 (96)\ttotal: 8.66s\tremaining: 4.63s\n", "131:\tlearn: 0.1464020\ttest: 0.1401162\tbest: 0.1400086 (96)\ttotal: 8.73s\tremaining: 4.56s\n", "132:\tlearn: 0.1462656\ttest: 0.1400077\tbest: 0.1400077 (132)\ttotal: 8.8s\tremaining: 4.49s\n", "133:\tlearn: 0.1460190\ttest: 0.1398864\tbest: 0.1398864 (133)\ttotal: 8.86s\tremaining: 4.42s\n", "134:\tlearn: 0.1459637\ttest: 0.1398490\tbest: 0.1398490 (134)\ttotal: 8.92s\tremaining: 4.35s\n", "135:\tlearn: 0.1457154\ttest: 0.1395519\tbest: 0.1395519 (135)\ttotal: 9s\tremaining: 
4.29s\n", "136:\tlearn: 0.1457117\ttest: 0.1395426\tbest: 0.1395426 (136)\ttotal: 9.06s\tremaining: 4.23s\n", "137:\tlearn: 0.1456945\ttest: 0.1395723\tbest: 0.1395426 (136)\ttotal: 9.14s\tremaining: 4.16s\n", "138:\tlearn: 0.1456748\ttest: 0.1395792\tbest: 0.1395426 (136)\ttotal: 9.2s\tremaining: 4.09s\n", "139:\tlearn: 0.1455239\ttest: 0.1398163\tbest: 0.1395426 (136)\ttotal: 9.28s\tremaining: 4.03s\n", "140:\tlearn: 0.1455034\ttest: 0.1398620\tbest: 0.1395426 (136)\ttotal: 9.37s\tremaining: 3.97s\n", "141:\tlearn: 0.1454291\ttest: 0.1400755\tbest: 0.1395426 (136)\ttotal: 9.43s\tremaining: 3.9s\n", "142:\tlearn: 0.1454196\ttest: 0.1401048\tbest: 0.1395426 (136)\ttotal: 9.49s\tremaining: 3.84s\n", "143:\tlearn: 0.1453446\ttest: 0.1400991\tbest: 0.1395426 (136)\ttotal: 9.56s\tremaining: 3.77s\n", "144:\tlearn: 0.1453092\ttest: 0.1400823\tbest: 0.1395426 (136)\ttotal: 9.63s\tremaining: 3.7s\n", "145:\tlearn: 0.1452807\ttest: 0.1401418\tbest: 0.1395426 (136)\ttotal: 9.71s\tremaining: 3.64s\n", "146:\tlearn: 0.1452572\ttest: 0.1401681\tbest: 0.1395426 (136)\ttotal: 9.77s\tremaining: 3.57s\n", "147:\tlearn: 0.1452437\ttest: 0.1401694\tbest: 0.1395426 (136)\ttotal: 9.82s\tremaining: 3.5s\n", "148:\tlearn: 0.1452409\ttest: 0.1401696\tbest: 0.1395426 (136)\ttotal: 9.88s\tremaining: 3.43s\n", "149:\tlearn: 0.1452313\ttest: 0.1401945\tbest: 0.1395426 (136)\ttotal: 9.95s\tremaining: 3.36s\n", "150:\tlearn: 0.1452248\ttest: 0.1402280\tbest: 0.1395426 (136)\ttotal: 10s\tremaining: 3.29s\n", "151:\tlearn: 0.1451256\ttest: 0.1401826\tbest: 0.1395426 (136)\ttotal: 10.1s\tremaining: 3.22s\n", "152:\tlearn: 0.1451233\ttest: 0.1401831\tbest: 0.1395426 (136)\ttotal: 10.1s\tremaining: 3.15s\n", "153:\tlearn: 0.1451067\ttest: 0.1402200\tbest: 0.1395426 (136)\ttotal: 10.2s\tremaining: 3.08s\n", "154:\tlearn: 0.1450766\ttest: 0.1402811\tbest: 0.1395426 (136)\ttotal: 10.3s\tremaining: 3.02s\n", "155:\tlearn: 0.1450670\ttest: 0.1402883\tbest: 0.1395426 (136)\ttotal: 10.3s\tremaining: 
2.95s\n", "156:\tlearn: 0.1450589\ttest: 0.1402922\tbest: 0.1395426 (136)\ttotal: 10.4s\tremaining: 2.88s\n", "157:\tlearn: 0.1450369\ttest: 0.1404510\tbest: 0.1395426 (136)\ttotal: 10.5s\tremaining: 2.81s\n", "158:\tlearn: 0.1450249\ttest: 0.1404327\tbest: 0.1395426 (136)\ttotal: 10.5s\tremaining: 2.75s\n", "159:\tlearn: 0.1450211\ttest: 0.1404312\tbest: 0.1395426 (136)\ttotal: 10.6s\tremaining: 2.68s\n", "160:\tlearn: 0.1449969\ttest: 0.1404053\tbest: 0.1395426 (136)\ttotal: 10.7s\tremaining: 2.61s\n", "161:\tlearn: 0.1449487\ttest: 0.1403818\tbest: 0.1395426 (136)\ttotal: 10.8s\tremaining: 2.55s\n", "162:\tlearn: 0.1449365\ttest: 0.1403904\tbest: 0.1395426 (136)\ttotal: 10.8s\tremaining: 2.49s\n", "163:\tlearn: 0.1447393\ttest: 0.1403556\tbest: 0.1395426 (136)\ttotal: 10.9s\tremaining: 2.42s\n", "164:\tlearn: 0.1447223\ttest: 0.1403308\tbest: 0.1395426 (136)\ttotal: 11s\tremaining: 2.35s\n", "165:\tlearn: 0.1446902\ttest: 0.1402935\tbest: 0.1395426 (136)\ttotal: 11s\tremaining: 2.29s\n", "166:\tlearn: 0.1446412\ttest: 0.1402642\tbest: 0.1395426 (136)\ttotal: 11.1s\tremaining: 2.22s\n", "167:\tlearn: 0.1445740\ttest: 0.1403223\tbest: 0.1395426 (136)\ttotal: 11.2s\tremaining: 2.15s\n", "168:\tlearn: 0.1445716\ttest: 0.1403393\tbest: 0.1395426 (136)\ttotal: 11.2s\tremaining: 2.08s\n", "169:\tlearn: 0.1445616\ttest: 0.1403364\tbest: 0.1395426 (136)\ttotal: 11.3s\tremaining: 2.02s\n", "170:\tlearn: 0.1445163\ttest: 0.1403374\tbest: 0.1395426 (136)\ttotal: 11.4s\tremaining: 1.95s\n", "171:\tlearn: 0.1444992\ttest: 0.1403069\tbest: 0.1395426 (136)\ttotal: 11.4s\tremaining: 1.88s\n", "172:\tlearn: 0.1444212\ttest: 0.1404872\tbest: 0.1395426 (136)\ttotal: 11.5s\tremaining: 1.81s\n", "173:\tlearn: 0.1443616\ttest: 0.1404046\tbest: 0.1395426 (136)\ttotal: 11.6s\tremaining: 1.75s\n", "174:\tlearn: 0.1441623\ttest: 0.1404320\tbest: 0.1395426 (136)\ttotal: 11.6s\tremaining: 1.68s\n", "175:\tlearn: 0.1441596\ttest: 0.1404399\tbest: 0.1395426 (136)\ttotal: 11.7s\tremaining: 
1.61s\n", "176:\tlearn: 0.1441157\ttest: 0.1403950\tbest: 0.1395426 (136)\ttotal: 11.7s\tremaining: 1.54s\n", "177:\tlearn: 0.1438309\ttest: 0.1402648\tbest: 0.1395426 (136)\ttotal: 11.8s\tremaining: 1.48s\n", "178:\tlearn: 0.1437287\ttest: 0.1402467\tbest: 0.1395426 (136)\ttotal: 11.9s\tremaining: 1.41s\n", "179:\tlearn: 0.1437173\ttest: 0.1402203\tbest: 0.1395426 (136)\ttotal: 12s\tremaining: 1.34s\n", "180:\tlearn: 0.1436975\ttest: 0.1402458\tbest: 0.1395426 (136)\ttotal: 12s\tremaining: 1.28s\n", "181:\tlearn: 0.1436947\ttest: 0.1402477\tbest: 0.1395426 (136)\ttotal: 12.1s\tremaining: 1.21s\n", "182:\tlearn: 0.1435053\ttest: 0.1401582\tbest: 0.1395426 (136)\ttotal: 12.2s\tremaining: 1.14s\n", "183:\tlearn: 0.1433447\ttest: 0.1402034\tbest: 0.1395426 (136)\ttotal: 12.2s\tremaining: 1.07s\n", "184:\tlearn: 0.1433105\ttest: 0.1402677\tbest: 0.1395426 (136)\ttotal: 12.3s\tremaining: 1.01s\n", "185:\tlearn: 0.1432248\ttest: 0.1403297\tbest: 0.1395426 (136)\ttotal: 12.4s\tremaining: 943ms\n", "186:\tlearn: 0.1431444\ttest: 0.1402979\tbest: 0.1395426 (136)\ttotal: 12.5s\tremaining: 875ms\n", "187:\tlearn: 0.1431310\ttest: 0.1404082\tbest: 0.1395426 (136)\ttotal: 12.5s\tremaining: 808ms\n", "188:\tlearn: 0.1431025\ttest: 0.1405215\tbest: 0.1395426 (136)\ttotal: 12.6s\tremaining: 740ms\n", "189:\tlearn: 0.1430438\ttest: 0.1404832\tbest: 0.1395426 (136)\ttotal: 12.7s\tremaining: 673ms\n", "190:\tlearn: 0.1429858\ttest: 0.1404577\tbest: 0.1395426 (136)\ttotal: 12.7s\tremaining: 605ms\n", "191:\tlearn: 0.1429675\ttest: 0.1404993\tbest: 0.1395426 (136)\ttotal: 12.8s\tremaining: 537ms\n", "192:\tlearn: 0.1427232\ttest: 0.1404004\tbest: 0.1395426 (136)\ttotal: 12.9s\tremaining: 471ms\n", "193:\tlearn: 0.1427034\ttest: 0.1404061\tbest: 0.1395426 (136)\ttotal: 12.9s\tremaining: 404ms\n", "194:\tlearn: 0.1426123\ttest: 0.1404027\tbest: 0.1395426 (136)\ttotal: 13s\tremaining: 337ms\n", "195:\tlearn: 0.1425407\ttest: 0.1403096\tbest: 0.1395426 (136)\ttotal: 13.1s\tremaining: 
269ms\n", "196:\tlearn: 0.1425383\ttest: 0.1403008\tbest: 0.1395426 (136)\ttotal: 13.1s\tremaining: 202ms\n", "197:\tlearn: 0.1425326\ttest: 0.1402927\tbest: 0.1395426 (136)\ttotal: 13.2s\tremaining: 135ms\n", "198:\tlearn: 0.1423886\ttest: 0.1403825\tbest: 0.1395426 (136)\ttotal: 13.3s\tremaining: 67.3ms\n", "199:\tlearn: 0.1421418\ttest: 0.1405326\tbest: 0.1395426 (136)\ttotal: 13.3s\tremaining: 0us\n", "\n", "bestTest = 0.1395425582\n", "bestIteration = 136\n", "\n", "Shrink model to first 137 iterations.\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 58, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = CatBoostClassifier(\n", " iterations=200,\n", " save_snapshot=True,\n", " snapshot_file='snapshot.bkp',\n", " snapshot_interval=1,\n", " random_seed=43\n", ")\n", "model.fit(\n", " X_train, y_train,\n", " eval_set=(X_validation, y_validation),\n", " cat_features=cat_features,\n", " verbose=True,\n", " plot=True\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model predictions" ] }, { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [], "source": [ "? model.predict_proba" ] }, { "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[0.0381 0.9619]\n", " [0.0072 0.9928]\n", " [0.015 0.985 ]\n", " ...\n", " [0.0099 0.9901]\n", " [0.0159 0.9841]\n", " [0.0243 0.9757]]\n" ] } ], "source": [ "print(model.predict_proba(X=X_validation))" ] }, { "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[1. 1. 1. ... 1. 1. 1.]\n" ] } ], "source": [ "print(model.predict(data=X_validation))" ] }, { "cell_type": "code", "execution_count": 62, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[3.2289 4.9216 4.1868 ... 
4.6093 4.1257 3.693 ]\n" ] } ], "source": [ "raw_pred = model.predict(data=X_validation, prediction_type='RawFormulaVal')\n", "print(raw_pred)" ] }, { "cell_type": "code", "execution_count": 63, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0.9619 0.9928 0.985 ... 0.9901 0.9841 0.9757]\n" ] } ], "source": [ "import math\n", "def sigmoid(x):\n", " return 1 / (1 + math.exp(-x))\n", "probabilities = [sigmoid(x) for x in raw_pred]\n", "print(np.array(probabilities))" ] }, { "cell_type": "code", "execution_count": 64, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[0.0381 0.9619]\n", " [0.0072 0.9928]\n", " [0.015 0.985 ]\n", " ...\n", " [0.0099 0.9901]\n", " [0.0159 0.9841]\n", " [0.0243 0.9757]]\n" ] } ], "source": [ "X_prepared = X_validation.values.astype(str).astype(object)\n", "# FeaturesData requires categorical feature values to be of type str\n", "\n", "fast_predictions = model.predict_proba(X=FeaturesData(cat_feature_data=X_prepared, \n", " cat_feature_names=list(X_validation)))\n", "\n", "print(fast_predictions)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Staged prediction" ] }, { "cell_type": "code", "execution_count": 65, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Iteration 0, predictions:\n", "[[0.3147 0.6853]\n", " [0.3119 0.6881]\n", " [0.3119 0.6881]\n", " ...\n", " [0.2477 0.7523]\n", " [0.2477 0.7523]\n", " [0.3119 0.6881]]\n", "Iteration 1, predictions:\n", "[[0.2677 0.7323]\n", " [0.2415 0.7585]\n", " [0.3053 0.6947]\n", " ...\n", " [0.2419 0.7581]\n", " [0.1878 0.8122]\n", " [0.3053 0.6947]]\n", "Iteration 2, predictions:\n", "[[0.2348 0.7652]\n", " [0.2108 0.7892]\n", " [0.2694 0.7306]\n", " ...\n", " [0.2112 0.7888]\n", " [0.1883 0.8117]\n", " [0.2694 0.7306]]\n" ] } ], "source": [ "predictions_gen = model.staged_predict_proba(data=X_validation, ntree_start=2, ntree_end=8, eval_period=2)\n", "for iteration, 
predictions in enumerate(predictions_gen):\n", " print('Iteration ' + str(iteration) + ', predictions:')\n", " print(predictions)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Solving MultiClassification problem" ] }, { "cell_type": "code", "execution_count": 66, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "267f049977264b0581dbacd8f56aa86b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "MetricVisualizer(layout=Layout(align_self=u'stretch', height=u'500px'))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, "execution_count": 66, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from catboost import CatBoostClassifier\n", "model = CatBoostClassifier(\n", " iterations=150,\n", " random_seed=43,\n", " loss_function='MultiClass'\n", " #loss_function='MultiClassOneVsAll'\n", ")\n", "model.fit(\n", " X_train, y_train,\n", " cat_features=cat_features,\n", " eval_set=(X_validation, y_validation),\n", " verbose=False,\n", " plot=True\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Metric evaluation on a new dataset" ] }, { "cell_type": "code", "execution_count": 67, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0:\tlearn: 0.6569860\ttotal: 41.3ms\tremaining: 8.21s\n", "50:\tlearn: 0.1950260\ttotal: 2.58s\tremaining: 7.52s\n", "100:\tlearn: 0.1700584\ttotal: 5.3s\tremaining: 5.19s\n", "150:\tlearn: 0.1641016\ttotal: 8.09s\tremaining: 2.62s\n", "199:\tlearn: 0.1604074\ttotal: 10.8s\tremaining: 0us\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 67, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = CatBoostClassifier(\n", " random_seed=63,\n", " iterations=200,\n", " learning_rate=0.03,\n", ")\n", 
"model.fit(\n", " X_train, y_train,\n", " cat_features=cat_features,\n", " verbose=50\n", ")" ] }, { "cell_type": "code", "execution_count": 68, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "metrics = model.eval_metrics(data=pool1, \n", " metrics=['Logloss','AUC'],\n", " ntree_start=0,\n", " ntree_end=0, \n", " eval_period=1,\n", " plot=True)" ] }, { "cell_type": "code", "execution_count": 69, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "AUC values:\n", "[0.4998 0.538 0.5504 0.5888 0.6503 0.6487 0.6487 0.6601 0.6601 0.6612\n", " 0.6614 0.67 0.6699 0.6697 0.6697 0.6698 0.6698 0.6698 0.6698 0.7265\n", " 0.7338 0.7376 0.7478 0.7478 0.7506 0.7591 0.7642 0.7877 0.8148 0.8265\n", " 0.8353 0.8459 0.8533 0.8564 0.8573 0.8748 0.8874 0.893 0.8948 0.898\n", " 0.9003 0.9052 0.9122 0.9159 0.92 0.9204 0.9209 0.9246 0.9262 0.9266\n", " 0.9279 0.9303 0.9311 0.9312 0.9327 0.9329 0.9337 0.9341 0.9341 0.9349\n", " 0.9354 0.9365 0.9391 0.9411 0.9424 0.9435 0.9446 0.9458 0.9463 0.9471\n", " 0.9478 0.9481 0.9482 0.9483 0.9486 0.9497 0.951 0.9514 0.9515 0.9519\n", " 0.9522 0.9526 0.953 0.9537 0.9543 0.9544 0.9545 0.9548 0.9548 0.9548\n", " 0.9548 0.9548 0.955 0.9551 0.9553 0.9554 0.9557 0.9557 0.9557 0.9557\n", " 0.956 0.9566 0.9566 0.9569 0.9568 0.957 0.9569 0.9572 0.9573 0.9575\n", " 0.9576 0.9576 0.9577 0.958 0.9587 0.9594 0.9595 0.9594 0.9601 0.9609\n", " 0.9609 0.9608 0.9608 0.9612 0.9616 0.9618 0.9622 0.9623 0.9623 0.9625\n", " 0.9625 0.9628 0.9629 0.9629 0.9629 0.963 0.9631 0.9632 0.9635 0.9636\n", " 0.9637 0.9638 0.9638 0.964 0.9641 0.9643 0.9645 0.9645 0.9645 0.9647\n", " 0.9647 0.9647 0.9647 0.9648 0.9648 0.9648 0.965 0.9652 0.9652 0.9653\n", " 0.9653 0.9654 0.9655 0.9655 0.9655 0.9656 0.9656 0.9657 0.9657 0.9657\n", " 0.9658 0.9658 0.9659 0.9659 0.966 0.966 0.966 0.966 0.966 0.966\n", " 0.966 0.966 0.9661 0.9662 
0.9663 0.9663 0.9663 0.9663 0.9663 0.9663\n", " 0.9665 0.9664 0.9666 0.9666 0.9666 0.9666 0.9666 0.9667 0.9667 0.9666]\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f6e28c97b9564447b6f30be735949b0b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "MetricVisualizer(layout=Layout(align_self=u'stretch', height=u'500px'))" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "print('AUC values:')\n", "print(np.array(metrics['AUC']))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Saving the model" ] }, { "cell_type": "code", "execution_count": 70, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 70, "metadata": {}, "output_type": "execute_result" } ], "source": [ "my_best_model = CatBoostClassifier(iterations=10)\n", "my_best_model.fit(\n", " X_train, y_train,\n", " eval_set=(X_validation, y_validation),\n", " cat_features=cat_features,\n", " verbose=False\n", ")" ] }, { "cell_type": "code", "execution_count": 71, "metadata": {}, "outputs": [], "source": [ "my_best_model.save_model('catboost_model.bin')" ] }, { "cell_type": "code", "execution_count": 72, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'loss_function': u'Logloss', 'verbose': 0, 'iterations': 10, 'logging_level': u'Silent'}\n", "0\n" ] } ], "source": [ "my_best_model.load_model('catboost_model.bin')\n", "print(my_best_model.get_params())\n", "print(my_best_model.random_seed_)" ] }, { "cell_type": "code", "execution_count": 73, "metadata": {}, "outputs": [], "source": [ "my_best_model.save_model('catboost_model.json', format='json')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Feature importances" ] }, { "cell_type": "code", "execution_count": 74, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0:\tlearn: 0.6569860\ttotal: 47.6ms\tremaining: 9.47s\n", "50:\tlearn: 0.1950260\ttotal: 2.81s\tremaining: 
8.22s\n", "100:\tlearn: 0.1700584\ttotal: 5.64s\tremaining: 5.53s\n", "150:\tlearn: 0.1641016\ttotal: 8.57s\tremaining: 2.78s\n", "199:\tlearn: 0.1604074\ttotal: 11.5s\tremaining: 0us\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 74, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = CatBoostClassifier(\n", " random_seed=63,\n", " iterations=200,\n", " learning_rate=0.03)\n", "\n", "model.fit(\n", " X_train, y_train,\n", " cat_features=cat_features,\n", " verbose=50\n", ")" ] }, { "cell_type": "code", "execution_count": 75, "metadata": {}, "outputs": [], "source": [ "fstrs = model.get_feature_importance(prettified=True)" ] }, { "cell_type": "code", "execution_count": 76, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'MGR_ID': 18.398056223683028,\n", " 'RESOURCE': 21.27625707750949,\n", " 'ROLE_CODE': 11.943230282191125,\n", " 'ROLE_DEPTNAME': 15.252806962601875,\n", " 'ROLE_FAMILY': 2.4789173515872176,\n", " 'ROLE_FAMILY_DESC': 9.984073192415531,\n", " 'ROLE_ROLLUP_1': 2.6278918788673646,\n", " 'ROLE_ROLLUP_2': 13.582536028250486,\n", " 'ROLE_TITLE': 4.456231002893895}" ] }, "execution_count": 76, "metadata": {}, "output_type": "execute_result" } ], "source": [ "{feature_name : value for feature_name, value in fstrs}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Shap values" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "https://github.com/slundberg/shap" ] }, { "cell_type": "code", "execution_count": 77, "metadata": {}, "outputs": [], "source": [ "def object_predictions(model, obj):\n", " print('Probability of class 1 = {:.4f}'.format(model.predict_proba([obj])[0][1]))\n", " print('Formula raw prediction = {:.4f}'.format(model.predict([obj], prediction_type='RawFormulaVal')[0]))\n", " " ] }, { "cell_type": "code", "execution_count": 78, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "The model has complex ctrs, so the SHAP values will be 
calculated approximately.\n" ] } ], "source": [ "import shap\n", "explainer = shap.TreeExplainer(model)\n", "shap_values = explainer.shap_values(Pool(X, y, cat_features=cat_features))\n" ] }, { "cell_type": "code", "execution_count": 79, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "3.3457409968646434" ] }, "execution_count": 79, "metadata": {}, "output_type": "execute_result" } ], "source": [ "explainer.expected_value" ] }, { "cell_type": "code", "execution_count": 80, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Probability of class 1 = 0.9798\n", "Formula raw prediction = 3.8820\n" ] } ], "source": [ "object_predictions(model, X.iloc[3,:])" ] }, { "cell_type": "code", "execution_count": 81, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n", "
\n", " " ], "text/plain": [ "" ] }, "execution_count": 81, "metadata": {}, "output_type": "execute_result" } ], "source": [ "shap.initjs()\n", "shap.force_plot(explainer.expected_value, shap_values[3,:], X.iloc[3,:])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The force plot above shows how each feature contributes to pushing the model output from the base value (the average model output over the training dataset we passed) to the final model output. Features that push the prediction higher are shown in red, and those that push it lower are shown in blue." ] }, { "cell_type": "code", "execution_count": 82, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Probability of class 1 = 0.6404\n", "Formula raw prediction = 0.5772\n" ] } ], "source": [ "object_predictions(model, X.iloc[91,:])" ] }, { "cell_type": "code", "execution_count": 83, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n", "
\n", " " ], "text/plain": [ "" ] }, "execution_count": 83, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import shap\n", "shap.initjs()\n", "shap.force_plot(explainer.expected_value, shap_values[91,:], X.iloc[91,:])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To get an overview of which features are most important for the model, we can plot the SHAP values of every feature for every sample. The plot below sorts features by the sum of SHAP value magnitudes over all samples and shows the distribution of each feature's impact on the model output. The color represents the feature value (red high, blue low)." ] }, { "cell_type": "code", "execution_count": 84, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "shap.summary_plot(shap_values, X)" ] }, { "cell_type": "code", "execution_count": 85, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "shap.summary_plot(shap_values, X, plot_type=\"bar\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Hyperparameter tuning" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Training speed" ] }, { "cell_type": "code", "execution_count": 86, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "80373e00f5534d56b89daaf137b6a7bf", "version_major": 2, "version_minor": 0 }, "text/plain": [ "MetricVisualizer(layout=Layout(align_self=u'stretch', height=u'500px'))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, "execution_count": 86, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from catboost import CatBoost\n", "fast_model = CatBoostClassifier(\n", " random_seed=63,\n", " iterations=150,\n", " learning_rate=0.01,\n", " boosting_type='Plain',\n", " bootstrap_type='Bernoulli',\n", " subsample=0.5,\n", " rsm=0.5,\n", " one_hot_max_size=20,\n", " leaf_estimation_iterations=2,\n", " max_ctr_complexity=1)\n", "\n", "fast_model.fit(\n", " X_train, y_train,\n", " cat_features=cat_features,\n", " verbose=False,\n", " plot=True\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Accuracy" ] }, { "cell_type": "code", "execution_count": 87, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "83c5a9d1574b4f498df146f5339d8d99", "version_major": 2, "version_minor": 0 }, "text/plain": [ "MetricVisualizer(layout=Layout(align_self=u'stretch', height=u'500px'))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, 
"execution_count": 87, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tunned_model = CatBoostClassifier(\n", " random_seed=63,\n", " iterations=1000,\n", " learning_rate=0.03,\n", " l2_leaf_reg=3,\n", " bagging_temperature=1,\n", " random_strength=1,\n", " one_hot_max_size=2,\n", " leaf_estimation_method='Newton'\n", ")\n", "tunned_model.fit(\n", " X_train, y_train,\n", " cat_features=cat_features,\n", " verbose=False,\n", " eval_set=(X_validation, y_validation),\n", " plot=True\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training the model after parameter tuning" ] }, { "cell_type": "code", "execution_count": 88, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Learning rate set to 0.044153\n", "0:\tlearn: 0.6405883\ttotal: 37.9ms\tremaining: 38.7s\n", "100:\tlearn: 0.1584501\ttotal: 5.77s\tremaining: 52.6s\n", "200:\tlearn: 0.1518465\ttotal: 11.5s\tremaining: 46.9s\n", "300:\tlearn: 0.1486421\ttotal: 17.5s\tremaining: 41.8s\n", "400:\tlearn: 0.1463685\ttotal: 23.8s\tremaining: 36.8s\n", "500:\tlearn: 0.1446445\ttotal: 30.1s\tremaining: 31.2s\n", "600:\tlearn: 0.1431370\ttotal: 36.5s\tremaining: 25.5s\n", "700:\tlearn: 0.1417911\ttotal: 42.7s\tremaining: 19.5s\n", "800:\tlearn: 0.1408268\ttotal: 48.8s\tremaining: 13.4s\n", "900:\tlearn: 0.1398684\ttotal: 55.1s\tremaining: 7.34s\n", "1000:\tlearn: 0.1390739\ttotal: 1m 1s\tremaining: 1.23s\n", "1020:\tlearn: 0.1389030\ttotal: 1m 2s\tremaining: 0us\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 88, "metadata": {}, "output_type": "execute_result" } ], "source": [ "best_model = CatBoostClassifier(\n", " random_seed=63,\n", " iterations=int(tunned_model.tree_count_ * 1.2),\n", ")\n", "best_model.fit(\n", " X, y,\n", " cat_features=cat_features,\n", " verbose=100\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Feature evaluation" ] }, { "cell_type": "code", "execution_count": 89, "metadata": {}, 
"outputs": [], "source": [ "from catboost.eval.catboost_evaluation import *\n", "learn_params = {'iterations': 20, # 2000\n", " 'learning_rate': 0.5, # we set big learning_rate,\n", " # because we have small\n", " # #iterations\n", " 'random_seed': 0,\n", " 'verbose': False,\n", " 'loss_function' : 'Logloss',\n", " 'boosting_type': 'Plain'}\n", "evaluator = CatboostEvaluation('amazon/train.tsv',\n", " fold_size=10000, # <= 50% of dataset\n", " fold_count=20,\n", " column_description='amazon/train.cd',\n", " partition_random_seed=0,\n", " #working_dir=... \n", ")\n", "\n", "result = evaluator.eval_features(learn_config=learn_params,\n", " eval_metrics=['Logloss', 'Accuracy'],\n", " features_to_eval=[6, 7, 8])" ] }, { "cell_type": "code", "execution_count": 90, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
 | PValue | Score | Quantile 0.005 | Quantile 0.995 | Decision
Features: 6 | 0.000219 | 1.539542 | 0.856090 | 2.174448 | GOOD
Features: 8 | 0.092963 | 0.363555 | -0.321535 | 0.947839 | UNKNOWN
Features: 7 | 0.135357 | 0.289296 | -0.296104 | 0.847476 | UNKNOWN
\n", "
" ], "text/plain": [ " PValue Score Quantile 0.005 Quantile 0.995 Decision\n", "Features: 6 0.000219 1.539542 0.856090 2.174448 GOOD\n", "Features: 8 0.092963 0.363555 -0.321535 0.947839 UNKNOWN\n", "Features: 7 0.135357 0.289296 -0.296104 0.847476 UNKNOWN" ] }, "execution_count": 90, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from catboost.eval.evaluation_result import *\n", "logloss_result = result.get_metric_results('Logloss')\n", "logloss_result.get_baseline_comparison(\n", " ScoreConfig(ScoreType.Rel, overfit_iterations_info=False)\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Calculate predictions for the contest" ] }, { "cell_type": "code", "execution_count": 91, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predictions:\n", "[[0.3846 0.6154]\n", " [0.0154 0.9846]\n", " [0.0106 0.9894]\n", " ...\n", " [0.0077 0.9923]\n", " [0.0641 0.9359]\n", " [0.0147 0.9853]]\n" ] } ], "source": [ "X_test = test_df.drop('id', axis=1)\n", "test_pool = Pool(data=X_test, cat_features=cat_features)\n", "contest_predictions = best_model.predict_proba(test_pool)\n", "print('Predictions:')\n", "print(contest_predictions)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Prepare the submission" ] }, { "cell_type": "code", "execution_count": 92, "metadata": {}, "outputs": [], "source": [ "f = open('submit.csv', 'w')\n", "f.write('Id,Action\\n')\n", "for idx in range(len(contest_predictions)):\n", " line = str(test_df['id'][idx]) + ',' + str(contest_predictions[idx][1]) + '\\n'\n", " f.write(line)\n", "f.close()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Submit your solution [here](https://www.kaggle.com/c/amazon-employee-access-challenge/submit).\n", "Good luck!!!" 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Bonus\n", "### Solving MultiClassification problem via Ranking" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For multiclass problems with many classes, it is sometimes better to solve the classification problem using ranking.\n", "To do that, we will build a dataset with groups.\n", "Every group represents one object from the initial dataset: it contains one row per candidate class, and each row gets one additional categorical feature holding that class value.\n", "The target is 1 if the candidate class value equals the object's correct class, and 0 otherwise.\n", "Thus each group has exactly one 1 among its labels, and zeros for the rest.\n", "You can put all possible class values into each group or, if there are too many labels, include only the hard negatives.\n", "We'll demonstrate this approach on a binary classification problem." ] }, { "cell_type": "code", "execution_count": 93, "metadata": {}, "outputs": [], "source": [ "from copy import deepcopy\n", "def build_multiclass_ranking_dataset(X, y, cat_features, label_values=[0,1], start_group_id=0):\n", " ranking_matrix = []\n", " ranking_labels = []\n", " group_ids = []\n", "\n", " X_train_matrix = X.values\n", " y_train_vector = y.values\n", "\n", " for obj_idx in range(X.shape[0]):\n", " obj = list(X_train_matrix[obj_idx])\n", "\n", " for label in label_values:\n", " obj_of_given_class = deepcopy(obj)\n", " obj_of_given_class.append(label)\n", " ranking_matrix.append(obj_of_given_class)\n", " ranking_labels.append(float(y_train_vector[obj_idx] == label)) \n", " group_ids.append(start_group_id + obj_idx)\n", " \n", " final_cat_features = deepcopy(cat_features)\n", " final_cat_features.append(X.shape[1]) # new feature that we are adding should be categorical.\n", " return Pool(ranking_matrix, ranking_labels, 
cat_features=final_cat_features, group_id = group_ids)" ] }, { 
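"cell_type": "markdown", "metadata": {}, "source": [ "The row expansion described above can be illustrated without CatBoost or pandas. The sketch below uses a hypothetical helper, `expand_to_groups` (not part of the CatBoost API), that mirrors the loop inside `build_multiclass_ranking_dataset` on a made-up two-object toy dataset, so you can check that every group receives exactly one positive target.\n", "\n", "
```python
# Hypothetical, dependency-free sketch of the row expansion done by
# build_multiclass_ranking_dataset: each object becomes one group that holds
# one row per candidate class value, appended as an extra feature.
def expand_to_groups(rows, labels, label_values, start_group_id=0):
    expanded_rows, targets, group_ids = [], [], []
    for obj_idx, (row, true_label) in enumerate(zip(rows, labels)):
        for candidate in label_values:
            # Copy the original features and append the candidate class value.
            expanded_rows.append(list(row) + [candidate])
            # Target is 1.0 only when the candidate matches the true class.
            targets.append(float(true_label == candidate))
            # All rows built from the same object share one group id.
            group_ids.append(start_group_id + obj_idx)
    return expanded_rows, targets, group_ids


# Toy data (made up for illustration): two objects, two categorical features each.
rows = [['a', 'x'], ['b', 'y']]
labels = [1, 0]
expanded_rows, targets, group_ids = expand_to_groups(rows, labels, [0, 1])

for group_id, row, target in zip(group_ids, expanded_rows, targets):
    print(group_id, row, target)

# Sanity check: every group contains exactly one positive target.
for group_id in set(group_ids):
    positives = sum(t for t, g in zip(targets, group_ids) if g == group_id)
    assert positives == 1.0
```
\n", "The notebook's function additionally wraps the result in a CatBoost `Pool` with `group_id` set and marks the appended class column as categorical." ] }, { 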
"cell_type": "code", "execution_count": 94, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "65c2a9bef94b40c3b9c0f9ec8f975280", "version_major": 2, "version_minor": 0 }, "text/plain": [ "MetricVisualizer(layout=Layout(align_self=u'stretch', height=u'500px'))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, "execution_count": 94, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from catboost import CatBoost\n", "params = {'iterations':150, 'learning_rate':0.01, 'l2_leaf_reg':30, 'random_seed':0, 'loss_function':'QuerySoftMax'}\n", "\n", "groupwise_train_pool = build_multiclass_ranking_dataset(X_train, y_train, cat_features, [0,1])\n", "groupwise_eval_pool = build_multiclass_ranking_dataset(X_validation, y_validation, cat_features, [0,1], X_train.shape[0])\n", "\n", "model = CatBoost(params)\n", "model.fit(\n", " X=groupwise_train_pool,\n", " verbose=False,\n", " eval_set=groupwise_eval_pool,\n", " plot=True\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Doing predictions with ranking mode" ] }, { "cell_type": "code", "execution_count": 96, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "('Raw values:', array([-1.3969, 1.4067]))\n", "('Probabilities', array([0.0571, 0.9429]))\n" ] } ], "source": [ "import math\n", "\n", "obj = list(X_validation.values[0])\n", "ratings = []\n", "for label in [0,1]:\n", " obj_with_label = deepcopy(obj)\n", " obj_with_label.append(label)\n", " rating = model.predict([obj_with_label])[0]\n", " ratings.append(rating)\n", "print('Raw values:', np.array(ratings))\n", "\n", "def soft_max(values):\n", " return [math.exp(val) / sum([math.exp(val) for val in values]) for val in values]\n", "\n", "print('Probabilities', np.array(soft_max(ratings)))" ] }, 
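{ "cell_type": "markdown", "metadata": {}, "source": [ "The `soft_max` helper above applies `math.exp` to the raw scores directly and recomputes the denominator for every element. That is fine for two small scores, but `math.exp` overflows when raw values are large. Below is a minimal NumPy sketch of a numerically stable alternative; `stable_softmax` is a hypothetical helper (not a CatBoost function), and it assumes the per-class raw scores for one object are given as a 1-D list or array.\n", "\n", "
```python
import numpy as np


def stable_softmax(values):
    # Softmax is unchanged when the same constant is subtracted from every
    # score, so shifting by the maximum only keeps np.exp from overflowing.
    values = np.asarray(values, dtype=float)
    exps = np.exp(values - values.max())
    # The denominator is computed once instead of once per element.
    return exps / exps.sum()


# Raw ranking scores taken from the output of the previous cell.
example_scores = [-1.3969, 1.4067]
print(np.round(stable_softmax(example_scores), 4))

# math.exp(1000.0) raises OverflowError, but the shifted version is fine.
print(stable_softmax([1000.0, 1001.0]))
```
\n", "With the scores above, this should reproduce the probabilities printed by the previous cell up to rounding." ] }, 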
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" }, "widgets": { "state": { "1057714ebc614324aa3ba2cf69408966": { "views": [ { "cell_index": 17 } ] }, "8381e9eed05f4a03905ae8a56d7ab4ea": { "views": [ { "cell_index": 48 } ] }, "f49684e8c5c44241bfe2c7f577f5cb41": { "views": [ { "cell_index": 53 } ] } }, "version": "1.2.0" } }, "nbformat": 4, "nbformat_minor": 2 }