{ "cells": [ { "cell_type": "markdown", "source": [ "## HW2 Question 2\r\n", "\r\n", "Your job is to predict the democracy index from real GDP per capita and other demographic features. In order to predict the democracy index, we are going to use the following models:\r\n", "1. Ridge Regression (10 points)\r\n", "2. Lasso Regression (10 points)\r\n", "3. Adaptive Lasso Regression (10 points)\r\n", "4. Elastic Net Regression (10 points)" ], "metadata": {} }, { "cell_type": "code", "execution_count": 40, "source": [ "#Required libraries\r\n", "import pandas as pd\r\n", "import asgl\r\n", "from sklearn.linear_model import ElasticNetCV, RidgeCV, LassoCV\r\n", "import logging\r\n", "from sklearn.model_selection import GridSearchCV, RepeatedKFold, cross_val_score, train_test_split\r\n", "from sklearn.metrics import mean_squared_error\r\n", "from sklearn.datasets import make_regression\r\n", "from dataclasses import dataclass\r\n", "import numpy as np\r\n", "import matplotlib.pyplot as plt\r\n" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 54, "source": [ "income_cleansed" ], "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " dem_ind log_gdppc log_pop age_1 age_2 age_3 age_4 \\\n", "38 0.530000 8.905374 9.933823 0.307737 0.241766 0.210866 0.151395 \n", "39 0.166667 9.016160 10.011580 0.302024 0.239600 0.207557 0.152852 \n", "40 0.833333 9.133990 10.084220 0.293686 0.245023 0.199240 0.154751 \n", "41 0.166667 9.202840 10.167740 0.292191 0.248464 0.190341 0.154945 \n", "42 0.833333 9.271142 10.243310 0.305165 0.237141 0.186167 0.152458 \n", "... ... ... ... ... ... ... ... 
\n", "1364 0.666667 7.831994 8.716208 0.487388 0.249146 0.135395 0.081204 \n", "1365 0.500000 7.876341 8.872487 0.490522 0.261934 0.124533 0.075827 \n", "1366 0.166667 7.913739 9.061492 0.477027 0.272603 0.132026 0.073196 \n", "1367 0.333333 7.977293 9.234155 0.462648 0.277990 0.141861 0.072411 \n", "1368 0.166667 7.882274 9.347926 0.454770 0.282207 0.146730 0.070168 \n", "\n", " age_5 educ age_median \n", "38 0.088237 4.988 26.799999 \n", "39 0.097967 5.214 27.200001 \n", "40 0.107299 5.876 27.400000 \n", "41 0.114059 5.845 27.299999 \n", "42 0.119069 6.618 27.200001 \n", "... ... ... ... \n", "1364 0.046867 2.147 15.600000 \n", "1365 0.047184 2.816 15.400000 \n", "1366 0.045149 2.828 16.000000 \n", "1367 0.045090 4.087 16.700001 \n", "1368 0.046125 4.428 17.000000 \n", "\n", "[679 rows x 10 columns]" ], "text/html": [ "
\n", " | dem_ind | \n", "log_gdppc | \n", "log_pop | \n", "age_1 | \n", "age_2 | \n", "age_3 | \n", "age_4 | \n", "age_5 | \n", "educ | \n", "age_median | \n", "
---|---|---|---|---|---|---|---|---|---|---|
38 | \n", "0.530000 | \n", "8.905374 | \n", "9.933823 | \n", "0.307737 | \n", "0.241766 | \n", "0.210866 | \n", "0.151395 | \n", "0.088237 | \n", "4.988 | \n", "26.799999 | \n", "
39 | \n", "0.166667 | \n", "9.016160 | \n", "10.011580 | \n", "0.302024 | \n", "0.239600 | \n", "0.207557 | \n", "0.152852 | \n", "0.097967 | \n", "5.214 | \n", "27.200001 | \n", "
40 | \n", "0.833333 | \n", "9.133990 | \n", "10.084220 | \n", "0.293686 | \n", "0.245023 | \n", "0.199240 | \n", "0.154751 | \n", "0.107299 | \n", "5.876 | \n", "27.400000 | \n", "
41 | \n", "0.166667 | \n", "9.202840 | \n", "10.167740 | \n", "0.292191 | \n", "0.248464 | \n", "0.190341 | \n", "0.154945 | \n", "0.114059 | \n", "5.845 | \n", "27.299999 | \n", "
42 | \n", "0.833333 | \n", "9.271142 | \n", "10.243310 | \n", "0.305165 | \n", "0.237141 | \n", "0.186167 | \n", "0.152458 | \n", "0.119069 | \n", "6.618 | \n", "27.200001 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
1364 | \n", "0.666667 | \n", "7.831994 | \n", "8.716208 | \n", "0.487388 | \n", "0.249146 | \n", "0.135395 | \n", "0.081204 | \n", "0.046867 | \n", "2.147 | \n", "15.600000 | \n", "
1365 | \n", "0.500000 | \n", "7.876341 | \n", "8.872487 | \n", "0.490522 | \n", "0.261934 | \n", "0.124533 | \n", "0.075827 | \n", "0.047184 | \n", "2.816 | \n", "15.400000 | \n", "
1366 | \n", "0.166667 | \n", "7.913739 | \n", "9.061492 | \n", "0.477027 | \n", "0.272603 | \n", "0.132026 | \n", "0.073196 | \n", "0.045149 | \n", "2.828 | \n", "16.000000 | \n", "
1367 | \n", "0.333333 | \n", "7.977293 | \n", "9.234155 | \n", "0.462648 | \n", "0.277990 | \n", "0.141861 | \n", "0.072411 | \n", "0.045090 | \n", "4.087 | \n", "16.700001 | \n", "
1368 | \n", "0.166667 | \n", "7.882274 | \n", "9.347926 | \n", "0.454770 | \n", "0.282207 | \n", "0.146730 | \n", "0.070168 | \n", "0.046125 | \n", "4.428 | \n", "17.000000 | \n", "
679 rows × 10 columns
\n", "