Created
September 9, 2017 06:43
-
-
Save doloopwhile/0b7534b1019b3c76da4b8d5be4e0cf39 to your computer and use it in GitHub Desktop.
ワインの等級を当ててみた
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# インターネットから直接データをダウンロードする | |
import matplotlib.pyplot as plt | |
import pandas as pd | |
import numpy as np | |
import datetime as dt | |
import statsmodels.formula.api as sm | |
import itertools | |
import math | |
def separate_data_frame(df, n):
    """Split *df* into *n* consecutive, position-based chunks (k-fold CV folds).

    Rows are distributed as evenly as possible: the first ``len(df) % n``
    chunks get one extra row. For the 1599-row red-wine data and ``n=10``
    this yields nine chunks of 160 rows followed by one of 159.

    Parameters
    ----------
    df : pandas.DataFrame or pandas.Series
        Data to split; rows are taken in their current order.
    n : int
        Number of chunks to produce; must be at least 1.

    Returns
    -------
    list
        Exactly *n* slices of *df* that together cover every row exactly once.
        Trailing chunks are empty only when ``len(df) < n``.

    Raises
    ------
    ValueError
        If *n* is less than 1.
    """
    if n < 1:
        raise ValueError("n must be at least 1, got {}".format(n))
    base, extra = divmod(len(df), n)
    chunks = []
    for i in range(n):
        # The first `extra` chunks absorb the remainder, one row each.
        start = i * base + min(i, extra)
        end = start + base + (1 if i < extra else 0)
        # .iloc keeps slicing explicitly positional for DataFrame and Series.
        chunks.append(df.iloc[start:end])
    return chunks
WINE_QUALITY_RED_URL = (
    'https://archive.ics.uci.edu/ml/machine-learning-databases/'
    'wine-quality/winequality-red.csv'
)


def read_data_frames(source=WINE_QUALITY_RED_URL):
    """Load the wine-quality table and split it into features and target.

    Parameters
    ----------
    source : str or file-like, optional
        Anything ``pandas.read_csv`` accepts (URL, path, or buffer), parsed
        as ``;``-separated. Defaults to the UCI red-wine quality CSV.

    Returns
    -------
    (pandas.DataFrame, pandas.Series)
        Columns are renamed to the integers ``1 .. k``. ``x`` holds every
        column except the last (``1 .. k-1``), and ``y`` is the last column
        (named ``k``). For the default 12-column file, the features are
        ``1..11`` and the target is ``12``.

    Raises
    ------
    ValueError
        If the table has fewer than two columns (no features left).
    """
    # Wine measurements plus their quality grade (last column).
    df = pd.read_csv(source, sep=";")
    ncols = len(df.columns)
    if ncols < 2:
        raise ValueError(
            "expected at least 2 columns (features + target), got {}".format(ncols))
    df.columns = list(range(1, ncols + 1))
    x = df[list(range(1, ncols))]
    y = df[ncols]
    return x, y
def fit_model(nx, y):
    """Fit an ordinary-least-squares regression of *y* on the columns of *nx*.

    statsmodels' ``OLS`` does not add an intercept by itself. If one is
    wanted, *nx* must already contain a constant column.

    Parameters
    ----------
    nx : pandas.DataFrame
        Design matrix (one row per sample).
    y : pandas.Series
        Target values, one per row of *nx*.

    Returns
    -------
    statsmodels regression results
        The object returned by ``OLS(...).fit()``; its ``params`` hold the
        fitted coefficients labelled by *nx*'s columns.
    """
    # Current statsmodels releases no longer export the capitalized ``OLS``
    # class from ``statsmodels.formula.api`` (only the lower-case formula
    # constructors). The array-based class lives in ``statsmodels.api``.
    import statsmodels.api as sm_api

    model = sm_api.OLS(y, nx)
    return model.fit()
def solve(result, nx):
    """Predict whole-number grades for each row of *nx*.

    Takes the matrix product of *nx* with the fitted coefficients
    ``result.params`` (pandas aligns *nx*'s columns with the parameter
    labels), then rounds every prediction to 0 decimal places with
    ``Series.round``.
    """
    coefficients = result.params
    raw_prediction = nx.dot(coefficients)
    return raw_prediction.round(decimals=0)
def accuracy(result, nx, y):
    """Return the fraction of rows whose rounded prediction equals the label.

    Parameters
    ----------
    result : fitted regression results
        Passed to ``solve``, which uses ``result.params``.
    nx : pandas.DataFrame
        Feature rows to predict.
    y : pandas.Series
        True grades for the same rows, aligned with *nx* by index.

    Returns
    -------
    float
        Number of exact hits divided by ``len(nx)``.

    Raises
    ------
    ZeroDivisionError
        If *nx* is empty.
    """
    errors = solve(result, nx) - y
    # Vectorized replacement for the former ``Series.iteritems`` loop, which
    # pandas 2.0 removed. Predictions are whole numbers after rounding, so
    # (for integer grades) a hit is an error of exactly 0. A NaN error, e.g.
    # from a misaligned index, now counts as a miss instead of crashing in int().
    hits = int((errors == 0).sum())
    return hits / len(nx)
# Try solving it with linear regression.
x, y = read_data_frames()  # x: wine feature DataFrame, y: quality-grade Series

# Standardize each column of x (zero mean, unit standard deviation).
print("入力の平均", x.mean())
print("入力の標準偏差", x.std())
# NaNs (e.g. from a constant column, whose std is 0) are replaced with 0.
nx = x.apply(lambda col: (col - col.mean()) / col.std(), axis='index').fillna(0)
# Add a constant (intercept) column.
nx[0] = np.ones(len(nx))
print("正規化した入力の平均", nx.mean())
print("正規化した入力の標準偏差", nx.std())

# Show one sample row: raw input, standardized input, and its true grade.
print("入力", x.iloc[0:1])
print("正規化した入力", nx.iloc[0:1])
print("正解", y.iloc[0:1])

# Split into 10 folds.
nx_frames = separate_data_frame(nx, 10)
y_frames = separate_data_frame(y, 10)

# 10-fold cross-validation: each fold serves once as the test set while the
# regression is fitted on the remaining nine folds.
accuracies = []
for n, (test_nx, test_y) in enumerate(zip(nx_frames, y_frames)):
    print("{}回目の試行:".format(n))
    nx_df = pd.concat([frame for i, frame in enumerate(nx_frames) if i != n])
    y_df = pd.concat([frame for i, frame in enumerate(y_frames) if i != n])
    # Linear regression on the training folds.
    result = fit_model(nx_df, y_df)
    # Measure the hit rate on the held-out fold.
    acc = accuracy(result, test_nx, test_y)
    print("正答率={}%, accuracy={}".format(math.floor(acc * 100), acc))
    accuracies.append(acc)

accuracies = pd.Series(accuracies)
print("平均正答率={}%, mean of accuracy={}".format(math.floor(accuracies.mean() * 100), accuracies.mean()))
# => 平均正答率=58%, mean of accuracy=0.589123427672956
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
$ python solve4.py | |
入力の平均 1 8.319637 | |
2 0.527821 | |
3 0.270976 | |
4 2.538806 | |
5 0.087467 | |
6 15.874922 | |
7 46.467792 | |
8 0.996747 | |
9 3.311113 | |
10 0.658149 | |
11 10.422983 | |
dtype: float64 | |
入力の分散 1 1.741096 | |
2 0.179060 | |
3 0.194801 | |
4 1.409928 | |
5 0.047065 | |
6 10.460157 | |
7 32.895324 | |
8 0.001887 | |
9 0.154386 | |
10 0.169507 | |
11 1.065668 | |
dtype: float64 | |
正規化した入力の平均 1 -1.860787e-17 | |
2 1.360874e-16 | |
3 -1.013712e-16 | |
4 -9.970483e-17 | |
5 3.143202e-16 | |
6 -7.036967e-17 | |
7 1.304634e-16 | |
8 2.361144e-14 | |
9 2.766670e-15 | |
10 7.152919e-16 | |
11 5.203953e-16 | |
0 1.000000e+00 | |
dtype: float64 | |
正規化した入力の分散 1 1.0 | |
2 1.0 | |
3 1.0 | |
4 1.0 | |
5 1.0 | |
6 1.0 | |
7 1.0 | |
8 1.0 | |
9 1.0 | |
10 1.0 | |
11 1.0 | |
0 0.0 | |
dtype: float64 | |
入力 1 2 3 4 5 6 7 \ | |
0 -0.528194 0.961576 -1.391037 -0.453077 -0.24363 -0.466047 -0.379014 | |
8 9 10 11 0 | |
0 0.5581 1.28824 -0.579025 -0.959946 1.0 | |
正規化した入力 1 2 3 4 5 6 7 \ | |
0 -0.528194 0.961576 -1.391037 -0.453077 -0.24363 -0.466047 -0.379014 | |
8 9 10 11 0 | |
0 0.5581 1.28824 -0.579025 -0.959946 1.0 | |
正解 0 5 | |
Name: 12, dtype: int64 | |
0回目の試行: | |
正答率=60%, accuracy=0.60625 | |
1回目の試行: | |
正答率=51%, accuracy=0.5125 | |
2回目の試行: | |
正答率=58%, accuracy=0.58125 | |
3回目の試行: | |
正答率=55%, accuracy=0.55625 | |
4回目の試行: | |
正答率=61%, accuracy=0.6125 | |
5回目の試行: | |
正答率=58%, accuracy=0.5875 | |
6回目の試行: | |
正答率=53%, accuracy=0.5375 | |
7回目の試行: | |
正答率=61%, accuracy=0.6125 | |
8回目の試行: | |
正答率=68%, accuracy=0.6875 | |
9回目の試行: | |
正答率=59%, accuracy=0.5974842767295597 | |
平均正答率=58, mean of accuracy=0.589123427672956 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment