diff --git a/README.md b/README.md index ac7a0a7..3497ea3 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,22 @@ # snippets -Misc code snippets \ No newline at end of file +Misc code snippets I sometimes need and always have to look up how it works... + + +## linear_regression.py + +Calculate the linear regression on two columns of a data frame. The resulting +object has the function `predict()` to calculate x or y values for a given +counterpart. + +```python + from linear_regression import linear_regression + df = pd.DataFrame({"temperature":[...], "signal":[...]}) + + regression = linear_regression(df, x="temperature", y="signal") + + repr(regression) == "Regression(intercept=1, coefficient=3, score=0.9998)" + + regression.predict(x=3) == 10 + regression.predict(y=7) == 2 +``` \ No newline at end of file diff --git a/linear_regression.py b/linear_regression.py new file mode 100644 index 0000000..7eb2e86 --- /dev/null +++ b/linear_regression.py @@ -0,0 +1,71 @@ +import pandas as pd +import pytest + +from dataclasses import dataclass +from sklearn import linear_model + + +@dataclass +class Regression: + intercept: float + coefficient: float + score: float + + @property + def coeff(self) -> float: + return self.coefficient + + @property + def r2(self) -> float: + return self.score + + def predict(self, x: int | float = None, y: int | float = None) -> float: + """predict a value if x or y is given""" + if x is not None: + return self.intercept + x * self.coefficient + if y is not None: + return (y - self.intercept) / self.coefficient + msg = "predict() expects 1 argument, got 0" + raise TypeError(msg) + + +def linear_regression(data: pd.DataFrame, *, x: str, y: str) -> Regression: + """calculates a linear regression for two columns of a DataFrame""" + x_values = data[x].values.reshape(-1, 1) + y_values = data[y].values.reshape(-1, 1) + fit = linear_model.LinearRegression().fit(x_values, y_values) + score = fit.score(x_values, y_values) + return Regression(fit.intercept_[0], fit.coef_[0][0], score) + + +# tests + + +@pytest.fixture() +def example_data() -> pd.DataFrame: + x = list(range(1, 6)) + y = [4.1, 6.9, 10.1, 12.9, 15.9] + return pd.DataFrame({"A": x, "B": y}) + + +def test_linear_regression(example_data): + result = linear_regression(example_data, x="A", y="B") + + assert isinstance(result, Regression) + assert pytest.approx(2.96) == result.coefficient + assert pytest.approx(2.96) == result.coeff + assert pytest.approx(1.1) == result.intercept + assert pytest.approx(0.9996349) == result.score + assert pytest.approx(0.9996349) == result.r2 + + +def test_regression_predict(example_data): + result = linear_regression(example_data, x="A", y="B") + + prediction = result.predict(10) + + assert pytest.approx(30.7) == prediction + assert pytest.approx(10) == result.predict(y=prediction) + + with pytest.raises(TypeError, match="expects 1 argument"): + result.predict()