Simple Models in Python
Building Linear Regression and Logistic Regression models with Scikit-Learn
What You'll Learn
- Scikit-Learn basics
- Linear Regression (predicting numbers)
- Logistic Regression (predicting categories)
- Model evaluation metrics
Scikit-Learn (sklearn)
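The de facto standard library for classical machine learning in Python. It is a third-party package: install it with pip install scikit-learn, then import it as sklearn.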
Standard API (family and Model are placeholders, e.g. linear_model and LinearRegression):
- Import the model: from sklearn.family import Model
- Instantiate: model = Model(hyperparameters)
- Fit: model.fit(X_train, y_train)
- Predict: model.predict(X_test)
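Because every estimator follows the same four steps, you can swap models without changing the surrounding code. The following minimal sketch runs both models from this article through one identical loop. The data is synthetic, generated with NumPy purely for illustration. score is the estimator's built-in evaluation method: it returns R² for regressors and accuracy for classifiers.

```python
import numpy as np
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.model_selection import train_test_split

# Synthetic data, made up for illustration: 200 rows, 2 features
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))
y_reg = 3 * X[:, 0] - 2 * X[:, 1] + rng.normal(scale=0.1, size=200)  # continuous target
y_clf = (X[:, 0] + X[:, 1] > 0).astype(int)                           # 0/1 target

scores = {}
for Model, y in [(LinearRegression, y_reg), (LogisticRegression, y_clf)]:
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    model = Model()                    # instantiate with default hyperparameters
    model.fit(X_train, y_train)        # learn parameters from the training split
    y_pred = model.predict(X_test)     # predict on rows the model has not seen
    # .score() returns R2 for regressors and accuracy for classifiers
    scores[Model.__name__] = model.score(X_test, y_test)

print(scores)
```

Only the model class and the kind of target change between iterations. Everything else is the same API.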
Linear Regression
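Used for predicting continuous values, such as a tip amount. It fits a straight line (or, with several features, a flat plane) that minimizes the squared error between predictions and true values.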
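To make the "straight line" concrete, here is a minimal sketch. It fits made-up data generated from y = 2x + 1 plus noise and reads the learned slope and intercept from the fitted attributes coef_ and intercept_. It then recomputes MSE and R² by hand, which shows what the metrics in the example below measure. The sketch evaluates on the training data only to keep the arithmetic visible. A real evaluation should use a held-out test set, as the main example does.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score

# Toy data, made up for illustration: roughly y = 2x + 1 with small noise
rng = np.random.default_rng(42)
x = rng.uniform(0, 10, size=50)
y = 2 * x + 1 + rng.normal(scale=0.5, size=50)

X = x.reshape(-1, 1)                     # scikit-learn expects a 2D feature matrix
model = LinearRegression().fit(X, y)     # fit() returns the fitted model itself
print("slope:", model.coef_[0], "intercept:", model.intercept_)

y_pred = model.predict(X)
# MSE: average squared difference between true and predicted values
mse = np.mean((y - y_pred) ** 2)
# R2: 1 - (residual sum of squares / total sum of squares around the mean)
r2 = 1 - np.sum((y - y_pred) ** 2) / np.sum((y - y.mean()) ** 2)
print("MSE:", mse, "R2:", r2)
```

The learned slope and intercept should land close to 2 and 1. The hand-computed values agree with mean_squared_error and r2_score.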
code.py
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score
import pandas as pd
import seaborn as sns
# 1. Load Data
df = sns.load_dataset('tips')
X = df[['total_bill']] # Features (2D array)
y = df['tip'] # Target (1D array)
# 2. Split Data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 3. Initialize and Train
model = LinearRegression()
model.fit(X_train, y_train)
# 4. Predict
y_pred = model.predict(X_test)
# 5. Evaluate
print("MSE:", mean_squared_error(y_test, y_pred))
print("R2 Score:", r2_score(y_test, y_pred))
Logistic Regression
Used for classification, despite the name. It predicts the probability that a row belongs to a class, then assigns the more likely class.
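The name comes from the fact that logistic regression still computes a linear score, z = w·x + b, just as linear regression does. It then passes that score through the sigmoid function 1 / (1 + e^(-z)) to get a probability between 0 and 1. For a binary problem, predict returns class 1 when that probability exceeds 0.5. The minimal sketch below uses made-up one-feature data and checks these manual steps against scikit-learn's decision_function, predict_proba and predict methods.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Toy binary data, made up for illustration: label tends to be 1 when the feature is large
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 1))
y = (X[:, 0] + rng.normal(scale=0.5, size=100) > 0).astype(int)

clf = LogisticRegression().fit(X, y)

# Step 1: a linear score, z = w*x + b (the "regression" part)
z = clf.decision_function(X)
# Step 2: squash the score into a probability with the sigmoid 1 / (1 + e^-z)
p = 1 / (1 + np.exp(-z))
# Step 3: predict class 1 when the probability exceeds 0.5 (equivalently, z > 0)
manual_pred = (p > 0.5).astype(int)

print(np.allclose(p, clf.predict_proba(X)[:, 1]))   # True
print(np.array_equal(manual_pred, clf.predict(X)))
```

Column 1 of predict_proba holds the probability of class 1. Column 0 holds 1 minus that value.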
code.py
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, confusion_matrix
# 1. Load Data (Titanic)
df = sns.load_dataset('titanic').dropna(subset=['age', 'fare', 'survived'])
X = df[['age', 'fare']]
y = df['survived']
# 2. Split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 3. Train
clf = LogisticRegression(max_iter=1000)  # extra iterations help the default solver converge on unscaled features
clf.fit(X_train, y_train)
# 4. Predict
y_pred = clf.predict(X_test)
# 5. Evaluate
print("Accuracy:", accuracy_score(y_test, y_pred))
print("Confusion Matrix:\n", confusion_matrix(y_test, y_pred))
Practice Exercise
Try to predict the price of diamonds!
code.py
from sklearn.linear_model import LinearRegression
import seaborn as sns
diamonds = sns.load_dataset('diamonds')
# Use 'carat' to predict 'price'
# ... your code here ...
Next Steps
Learn how to make your analysis reproducible and shareable.