In this article I posit that many scientific computing models are in fact finite state machines. My goal is not to convince you that this has some deeper meaning, but rather to discuss the resulting design implications and review the current approaches in several OSS libraries.
Models
Let's build a simple linear regression model using the handy least squares solver in the NumPy library:
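```python
import numpy as np
from numpy.linalg import lstsq


class LinearRegression:
    def __init__(self):
        # No coefficients exist until fit() is called.
        self.model = None

    def fit(self, x: np.ndarray, y: np.ndarray) -> None:
        # lstsq returns (solution, residuals, rank, singular values); keep the solution.
        self.model = lstsq(x, y, rcond=None)[0]

    def predict(self, x: np.ndarray) -> np.ndarray:
        # x @ coefficients works for a single sample and for a batch of rows.
        return np.dot(x, self.model)

    def __len__(self):
        return len(self.model)
```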
We can now do:
```python
num_rows = 50
num_features = 3
x = np.random.uniform(size=[num_rows, num_features]) * 10
# sum of inputs + some noise to make it harder
y = np.sum(x, axis=1) + np.random.normal(size=num_rows, scale=5)

linear_regression = LinearRegression()
linear_regression.fit(x, y)
# output should be 1 + 2 + 3 = 6
linear_regression.predict(np.array([1, 2, 3]))
# >>> 6.200 (on my machine)
```
This is a common API followed by most popular ML toolkits, including the scikit-learn estimators, catboost, xgboost & lightgbm!
So what's the issue?
We implicitly assume that our objects will be used in a specific sequential manner, with `fit` only being called after initialization, and `predict` only being called after `fit`. If one were to try to call `predict` prior to `fit`, they'd be faced with an error! We also assume that `len` is only called once the model has been fit.
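To make the failure concrete, here is a minimal sketch using a trimmed copy of the class, keeping only the methods that can be reached before `fit`. The `None` placeholder for the coefficients is an assumption made for this illustration. Note that the resulting errors say nothing about fitting; they only complain about the placeholder value.

```python
import numpy as np


class LinearRegression:
    """Trimmed copy of the class above: only the parts reachable before fit()."""

    def __init__(self):
        self.model = None  # assumption: no coefficients until fit() runs

    def predict(self, x: np.ndarray) -> np.ndarray:
        return np.dot(x, self.model)

    def __len__(self) -> int:
        return len(self.model)


unfitted = LinearRegression()
errors = []
for label, call in [
    ("predict", lambda: unfitted.predict(np.array([1.0, 2.0, 3.0]))),
    ("len", lambda: len(unfitted)),
]:
    try:
        call()
    except TypeError as err:
        # Both calls fail on the None placeholder, not with a "not fitted" message.
        errors.append(label)
        print(f"{label} before fit -> TypeError: {err}")
```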
We can visualize this on a flow chart of available states that can be reached via the method calls available to us in that given state:
{{< mermaid >}}
flowchart LR
y("Uninitialized Model") --> |"__init__()"| h("Initialized Model")
h --> |"fit()"| r("Fit Model")
r --> |"fit()"| r
r --> |"predict()"| su[/"np.ndarray"/]
r --> |"__len__()"| suf[/"int"/]
{{< /mermaid >}}
So from a fitted model we can call `predict` and `len` and so on.
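The same chart can be written down as an explicit transition table, which is one way to see the finite state machine hiding in the API. The sketch below is purely illustrative: `ModelState`, `TRANSITIONS` and `step` are hypothetical names, and `predict` and `__len__` are modelled as leaving the model in the fitted state, since in the chart they produce outputs rather than new states.

```python
from enum import Enum, auto


class ModelState(Enum):
    UNINITIALIZED = auto()
    INITIALIZED = auto()
    FITTED = auto()


# (current state, method) -> next state. Pairs that are absent are invalid calls.
TRANSITIONS = {
    (ModelState.UNINITIALIZED, "__init__"): ModelState.INITIALIZED,
    (ModelState.INITIALIZED, "fit"): ModelState.FITTED,
    (ModelState.FITTED, "fit"): ModelState.FITTED,
    (ModelState.FITTED, "predict"): ModelState.FITTED,
    (ModelState.FITTED, "__len__"): ModelState.FITTED,
}


def step(state: ModelState, method: str) -> ModelState:
    """Return the state reached by calling `method`, or raise if the call is invalid."""
    try:
        return TRANSITIONS[(state, method)]
    except KeyError:
        raise RuntimeError(
            f"{method}() is not available in state {state.name}"
        ) from None


# Walk the happy path from the chart: __init__ -> fit -> predict.
state = ModelState.UNINITIALIZED
for call in ("__init__", "fit", "predict"):
    state = step(state, call)
print(state.name)
```

Any call that is missing from the table, such as `predict` on a model that has only been initialized, is rejected by `step` with a `RuntimeError`.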
How is this type of restriction implemented in production code bases? We can turn to some of the OSS libraries mentioned earlier to investigate.
Let's pick sklearn:
```python
import numpy as np
from sklearn.linear_model import LinearRegression

model = LinearRegression()
model.predict(np.array([1, 2, 3]))
# >>> NotFittedError: This LinearRegression instance is not fitted yet. Call 'fit' with appropriate arguments before using this estimator.
```
We get a custom `NotFittedError`. Let's look at how this is implemented by following the call graph.
When we call `.predict`, we actually call `self._decision_function(X)`, which in turn calls a function: `check_is_fitted`.
This function claims to do the following in its docstring:
Checks if the estimator is fitted by verifying the presence of
fitted attributes (ending with a trailing underscore) and otherwise
raises a NotFittedError with the given message.
The code seems to support this: it sets a boolean variable `fitted` by checking for all of the required attributes.
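As a rough illustration of the idea in that docstring, the sketch below looks for instance attributes whose names end in a trailing underscore. It is a simplified stand-in and not sklearn's actual implementation: `check_is_fitted_sketch` and `ToyEstimator` are hypothetical names, and the `NotFittedError` defined here is a local placeholder for the real exception class.

```python
class NotFittedError(Exception):
    """Hypothetical stand-in for sklearn.exceptions.NotFittedError."""


def check_is_fitted_sketch(estimator) -> None:
    """Raise NotFittedError unless the estimator carries a fitted attribute.

    Assumption: a fitted attribute is any instance attribute whose name ends
    with an underscore (dunder names are ignored).
    """
    fitted = any(
        name.endswith("_") and not name.startswith("__")
        for name in vars(estimator)
    )
    if not fitted:
        raise NotFittedError(
            f"This {type(estimator).__name__} instance is not fitted yet."
        )


class ToyEstimator:
    def fit(self) -> "ToyEstimator":
        # Learned state is stored with a trailing underscore, as the docstring describes.
        self.coef_ = [1.0, 1.0, 1.0]
        return self

    def predict(self) -> float:
        check_is_fitted_sketch(self)
        return sum(self.coef_)


try:
    ToyEstimator().predict()
except NotFittedError as err:
    print(f"rejected: {err}")

print(ToyEstimator().fit().predict())
```

Note that nothing about the object itself records the state here: the check infers it from attribute names each time a guarded method is called.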
So it seems that the `fitted` state we discussed is checked manually, by verifying the presence of the expected attributes. While this approach works, it feels rather unaesthetic: we are not enforcing the state of our class through some native functionality, but rather inferring it by checking that the object looks like it's in a fitted state.
Furthermore, it forces us to add extra lines of code whose sole purpose is to check that we are in a given state when calling specific methods or properties.
The sklearn approach would look something like:
```python
import numpy as np
from numpy.linalg import lstsq


class NotFittedError(Exception):
    pass


class LinearRegression:
    def __init__(self):
        self.is_fitted = False
        # No coefficients exist until fit() is called.
        self.model = None

    def fit(self, x: np.ndarray, y: np.ndarray) -> None:
        self.model = lstsq(x, y, rcond=None)[0]
        self.is_fitted = True

    def predict(self, x: np.ndarray) -> np.ndarray:
        if not self.is_fitted:
            raise NotFittedError("Model must be .fit() before calling .predict()")
        # x @ coefficients works for a single sample and for a batch of rows.
        return np.dot(x, self.model)

    def __len__(self):
        if not self.is_fitted:
            raise NotFittedError("Model must be .fit() before calling len()")
        return len(self.model)
```
where we use the `is_fitted` boolean attribute and raise a custom error whenever a method is called inappropriately.