You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
92 lines
2.7 KiB
92 lines
2.7 KiB
import dataclasses |
|
import pandas as pd |
|
import pytest |
|
|
|
from sklearn import linear_model |
|
|
|
|
|
@dataclasses.dataclass |
|
class Regression: |
|
intercept: float |
|
coefficient: float |
|
score: float |
|
|
|
@property |
|
def coeff(self) -> float: |
|
return self.coefficient |
|
|
|
@property |
|
def r2(self) -> float: |
|
return self.score |
|
|
|
def predict(self, *, x: int | float = None, y: int | float = None) -> float: |
|
"""predict a value if x or y is given""" |
|
if x is not None: |
|
return self.intercept + x * self.coefficient |
|
if y is not None: |
|
return (y - self.intercept) / self.coefficient |
|
msg = "predict() expects a keyword argument 'x' or 'y'" |
|
raise TypeError(msg) |
|
|
|
def to_dict(self): |
|
return dataclasses.asdict(self) |
|
|
|
|
|
def linear_regression(data: pd.DataFrame, *, x: str, y: str) -> Regression:
    """fits a straight line through two columns of a DataFrame

    The column named by ``x`` is the single feature and the column named by
    ``y`` is the target. The fitted intercept, slope and the model's score on
    the same data are returned as a ``Regression``.
    """
    # both columns are turned into (n, 1) column vectors before fitting
    features = data[x].values.reshape(-1, 1)
    targets = data[y].values.reshape(-1, 1)

    model = linear_model.LinearRegression()
    model.fit(features, targets)
    fit_score = model.score(features, targets)

    # the target is 2-D, so intercept_ is indexed once and coef_ twice
    # to pull out the scalar values
    return Regression(
        intercept=model.intercept_[0],
        coefficient=model.coef_[0][0],
        score=fit_score,
    )
|
|
|
|
|
# tests |
|
|
|
|
|
@pytest.fixture()
def example_data() -> pd.DataFrame:
    """five-row DataFrame: column A holds 1 to 5, column B the paired sample values"""
    columns = {
        "A": [1, 2, 3, 4, 5],
        "B": [4.1, 6.9, 10.1, 12.9, 15.9],
    }
    return pd.DataFrame(columns)
|
|
|
|
|
def test_linear_regression(example_data):
    """the fit of B against A yields the expected slope, intercept and score"""
    fitted = linear_regression(example_data, x="A", y="B")

    assert isinstance(fitted, Regression)

    # slope, via the field and via its alias
    expected_slope = pytest.approx(2.96)
    assert expected_slope == fitted.coefficient
    assert expected_slope == fitted.coeff

    assert pytest.approx(1.1) == fitted.intercept

    # score, via the field and via its alias
    expected_score = pytest.approx(0.9996349)
    assert expected_score == fitted.score
    assert expected_score == fitted.r2
|
|
|
|
|
def test_regression_predict(example_data):
    """predicting y from x and then x from that y round-trips"""
    model = linear_regression(example_data, x="A", y="B")

    predicted_y = model.predict(x=10)
    assert pytest.approx(30.7) == predicted_y

    # feeding the prediction back in must recover the original x
    recovered_x = model.predict(y=predicted_y)
    assert pytest.approx(10) == recovered_x
|
|
|
|
|
def test_regression_predict_exceptions(example_data):
    """predict() raises TypeError when called without keywords or positionally"""
    model = linear_regression(example_data, x="A", y="B")

    # neither x nor y supplied
    no_keyword = pytest.raises(TypeError, match="expects a keyword")
    with no_keyword:
        model.predict()

    # arguments are keyword-only, so a positional value is rejected
    positional = pytest.raises(TypeError, match="takes 1 positional argument but")
    with positional:
        model.predict(1)
|
|
|
|
|
def test_regression_to_dict(example_data):
    """to_dict() exposes exactly the three fields with the fitted values"""
    as_dict = linear_regression(example_data, x="A", y="B").to_dict()

    assert sorted(as_dict) == ["coefficient", "intercept", "score"]
    assert pytest.approx(2.96) == as_dict["coefficient"]
    assert pytest.approx(1.1) == as_dict["intercept"]
    assert pytest.approx(0.9996349) == as_dict["score"]
|
|
|