Leverage Python Inheritance in ML Projects
Introduction
Many people approaching Machine Learning don't have a strong background in computer engineering, and when they need to work on a real product their code can be messy and difficult to manage. This is why I always strongly recommend learning to use coding best practices which will enable you to work smoothly within a team and level up the project you're working on. Today I want to talk about Python inheritance and show some simple examples of how to use it within the field of Machine Learning.
In software development and other information technology fields, technical debt (also known as design debt or code debt) is the implied cost of future reworking because a solution prioritizes expedience over long-term design.
If you want to learn more about design patterns, you might find some of my previous articles useful.
Python Inheritance
Inheritance is not just a Python concept but a general concept in Object-Oriented Programming (OOP). So in this tutorial we will deal with classes and objects, a programming paradigm that is used less often in Python than in other languages like Java.
In OOP, we can define a general class representing something in the world, for example, a Person which we simply define by a name, surname and age in the following way.
Python">class Person:
def __init__(self, name, surname, age):
self.name = name
self.surname = surname
self.age = age
def __str__(self):
return f"Name: {self.name}, surname: {self.surname}, age: {self.age}"
def grow(self):
self.age +=1
In this class, we defined a simple constructor (__init__). Then we defined the __str__ method, which takes care of printing the object the way we want. Finally, we have the grow() method to make the person one year older.
Now we can instantiate an object and use this class.
person = Person("Marcello", "Politi", 28)
person.grow()
print(person)
# output will be
# Name: Marcello, surname: Politi, age: 29
Now what if we want to define a particular type of person, for example, a worker? Well, we can do the same thing as before, but we add another input variable for the salary.
class Worker:
    def __init__(self, name, surname, age, salary):
        self.name = name
        self.surname = surname
        self.age = age
        self.salary = salary

    def __str__(self):
        return f"Name: {self.name}, surname: {self.surname}, age: {self.age}, salary: {self.salary}"

    def grow(self):
        self.age += 1
That's it. But is this the best way to implement it? Most of the Worker code is identical to the Person code: a worker is a particular kind of person, so the two classes share a lot.
What we can do instead is tell Python that a Worker should inherit everything from Person, and then manually add only the things a generic person doesn't have.
class Worker(Person):
    def __init__(self, name, surname, age, salary):
        super().__init__(name, surname, age)
        self.salary = salary

    def __str__(self):
        text = super().__str__()
        return text + f", salary: {self.salary}"
In the Worker class, the constructor calls the constructor of the Person class through super(), and then also sets the salary attribute.
The same goes for the __str__ method: we reuse the text returned by Person's __str__ via super(), and append the salary when printing the object.
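To see the inherited behaviour in action, we can instantiate a Worker exactly like we did with Person (the values here, including the salary, are just an example):

worker = Worker("Marcello", "Politi", 28, 50000)
worker.grow()  # grow() is inherited from Person as-is
print(worker)
# output will be
# Name: Marcello, surname: Politi, age: 29, salary: 50000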
Inheritance in Machine Learning
There are no rules on when to use inheritance in Machine Learning. I don't know what project you're working on, or what your code looks like. I just want to stress the fact that you should adopt an OOP paradigm in your codebase. But still, let's see some examples of how to use inheritance.
Define a BaseMLModel
Let's code a base machine learning model class defined by a few standard attributes. This class will have a method to load the data, one to train, one to evaluate, and one to preprocess the data. However, each specific model preprocesses the data differently, so the subclasses that inherit from the base model must rewrite the preprocessing method (and, when the preprocessing changes the features, apply the same transformation to the test data before evaluating). Note that BaseMLModel itself inherits from the ABC class. This is a way to tell Python that the class is abstract: it should never be instantiated directly, it is only a template for building subclasses.
The same is true for preprocess_train_data, which is marked as an @abstractmethod. This means that subclasses must implement this method themselves.
from abc import ABC, abstractmethod
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import load_iris
import numpy as np
class BaseMLModel(ABC):
    def __init__(self, test_size=0.2, random_state=42):
        self.model = None  # This will be set in subclasses
        self.test_size = test_size
        self.random_state = random_state
        self.X_train = None
        self.X_test = None
        self.y_train = None
        self.y_test = None

    def load_data(self, X, y):
        self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(
            X, y, test_size=self.test_size, random_state=self.random_state
        )

    @abstractmethod
    def preprocess_train_data(self):
        """Each model can define custom preprocessing for training data."""
        pass

    def preprocess_test_data(self):
        """By default the test data is used as-is; subclasses can override this
        to apply the same transformation they used on the training data."""
        return self.X_test

    def train(self):
        self.X_train, self.y_train = self.preprocess_train_data()
        self.model.fit(self.X_train, self.y_train)

    def evaluate(self):
        predictions = self.model.predict(self.preprocess_test_data())
        return accuracy_score(self.y_test, predictions)
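Since preprocess_train_data is abstract, Python refuses to instantiate BaseMLModel directly, which is exactly what makes it a pure template. A quick way to check it (the exact error wording depends on your Python version):

try:
    BaseMLModel()
except TypeError as e:
    print(e)
    # e.g. "Can't instantiate abstract class BaseMLModel with abstract method preprocess_train_data"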
Now let's see how we can inherit from this class. First, we can implement a LogisticRegressionModel, which has its own preprocessing: it standardizes the training features and applies the same mean and standard deviation to the test features at evaluation time.
class LogisticRegressionModel(BaseMLModel):
    def __init__(self, **kwargs):
        super().__init__()
        self.model = LogisticRegression(**kwargs)
        self.mean = None
        self.std = None

    def preprocess_train_data(self):
        # Standardize features for Logistic Regression
        self.mean = self.X_train.mean(axis=0)
        self.std = self.X_train.std(axis=0)
        X_train_scaled = (self.X_train - self.mean) / self.std
        return X_train_scaled, self.y_train

    def preprocess_test_data(self):
        # Reuse the statistics computed on the training set
        return (self.X_test - self.mean) / self.std
Then we can define as many subclasses as we want. I define here one for a Random Forest, which keeps only the features with the highest variance, both at training and at evaluation time.
class RandomForestModel(BaseMLModel):
    def __init__(self, n_important_features=2, **kwargs):
        super().__init__()
        self.model = RandomForestClassifier(**kwargs)
        self.n_important_features = n_important_features
        self.top_features_indices = None

    def preprocess_train_data(self):
        # Select top `n_important_features` features based on variance
        feature_variances = np.var(self.X_train, axis=0)
        self.top_features_indices = np.argsort(feature_variances)[-self.n_important_features:]
        X_train_selected = self.X_train[:, self.top_features_indices]
        return X_train_selected, self.y_train

    def preprocess_test_data(self):
        # Keep the same features that were selected on the training set
        return self.X_test[:, self.top_features_indices]
Then we can use all of this in our main function:
if __name__ == "__main__":
    # Load dataset
    data = load_iris()
    X, y = data.data, data.target

    # Logistic Regression
    log_reg_model = LogisticRegressionModel(max_iter=200)
    log_reg_model.load_data(X, y)
    log_reg_model.train()
    print(f"Logistic Regression Accuracy: {log_reg_model.evaluate()}")

    # Random Forest
    rf_model = RandomForestModel(n_estimators=100, n_important_features=3)
    rf_model.load_data(X, y)
    rf_model.train()
    print(f"Random Forest Accuracy: {rf_model.evaluate()}")
Final Thoughts
One of the main benefits of Python inheritance in ML projects is the design of modular, maintainable, and scalable codebases. Inheritance helps avoid redundant code by keeping common logic in a base class, such as BaseMLModel, thereby reducing duplication. It also makes it easy to encapsulate common behaviours in the base class while allowing subclasses to define the particular details.
The main benefit in my opinion is that a well-organized, object-oriented codebase allows multiple developers within a team to work independently on separate parts. In our example, a lead engineer could define the base model, and then each developer could focus on a single algorithm and write the subclass.
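For example, a teammate who wants to plug in a new algorithm only has to write a small subclass. The KNNModel below is a hypothetical sketch of such a contribution (it is not part of the example above, and it simply uses the training data as-is):

from sklearn.neighbors import KNeighborsClassifier

class KNNModel(BaseMLModel):
    def __init__(self, **kwargs):
        super().__init__()
        self.model = KNeighborsClassifier(**kwargs)

    def preprocess_train_data(self):
        # Hypothetical subclass: no custom preprocessing, the training data is used as-is
        return self.X_train, self.y_train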
Before diving into complex design patterns, focus on leveraging OOP best practices. Doing so will make you a better programmer compared to many others in the ML field.
Follow me on Medium if you like this article!