Last active
February 16, 2017 03:34
-
-
Save pataiji/cf2d7cea4aebac42a7c39af556cf3d4b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env ruby
require 'matrix'

# Class labels used in the second column of the training data.
CLASS_1 = 1
CLASS_2 = 2

# Training data: each row is [feature value, class label].
data = Matrix.rows([
  [ 1.0, CLASS_1],
  [ 0.5, CLASS_1],
  [-0.2, CLASS_2],
  [-0.4, CLASS_1],
  [-1.3, CLASS_2],
  [-2.0, CLASS_2],
])

# Feature matrix: one row per sample, a single column holding the feature.
features = Matrix.column_vector(data.column(0))
# Class label of every sample, in row order.
labels = data.column(1)
# Initial weight vector (index 0 is the bias weight, index 1 the slope).
wvec = Vector.elements([0.2, 0.3])
# Augmented samples: a leading column of ones so that xvec[0] = 1.
xvecs = Matrix.hstack(Matrix.build(features.row_count, 1) { 1 }, features)
# Learns the weight coefficients: runs one training pass and returns the
# updated weight vector.
#
# For each weight index `key`, the samples are visited in row order. For every
# sample the residual (wvec . xvec - target) times xvec[key] is added to a
# running sum, and weight `key` is immediately decreased by `rho` times that
# running sum.
# NOTE(review): the weight is adjusted after every sample using the running
# sum, not once per full sweep -- confirm this schedule is intended.
#
# Half of the squared residuals accumulated over all (weight index, sample)
# visits is printed to stdout.
#
# @param xvecs [Matrix] one augmented sample per row
# @param wvec [Vector] current weights; the caller's vector is not mutated
# @param tvec [#[]] target value for each sample, indexed by row number
# @param rho [Numeric] learning rate
# @return [Vector] the updated weights
def train(xvecs, wvec, tvec, rho: 0.2)
  squared_error = 0
  wvec.size.times do |key|
    running_sum = 0
    xvecs.row_count.times do |i|
      xvec = xvecs.row(i)
      residual = wvec.dot(xvec) - tvec[i]
      running_sum += residual * xvec[key]
      squared_error += residual**2
      # Update by position: selecting the component by value would also
      # change every other weight that happens to hold the same number.
      updated = wvec.to_a
      updated[key] -= rho * running_sum
      wvec = Vector.elements(updated)
    end
  end
  puts squared_error / 2
  wvec
end
# Computes the weight coefficients by applying `train` repeatedly, feeding
# each pass the weights returned by the previous one.
#
# @param xvecs [Matrix] samples, passed through to `train`
# @param wvec [Vector] starting weights
# @param tvec [#[]] target values, passed through to `train`
# @param iterations [Integer] number of training passes to run
# @return [Vector] the weights after the final pass (the starting weights
#   when `iterations` is 0)
def get_wvec(xvecs, wvec, tvec, iterations: 100)
  iterations.times { wvec = train(xvecs, wvec, tvec) }
  wvec
end
# Classifies a scalar feature by comparing two linear discriminants
# g_k(x) = wvec_k[1] * x + wvec_k[0].
#
# @param wvec1 [#[]] weights of the class-1 discriminant (index 0 is the
#   constant term, index 1 the coefficient of x)
# @param wvec2 [#[]] weights of the class-2 discriminant, same layout
# @param x [Numeric] feature value to classify
# @return [Integer, nil] CLASS_1 when g1 > g2, CLASS_2 when g1 < g2, and nil
#   when neither comparison holds (e.g. a tie)
def detect_class(wvec1, wvec2, x)
  g1 = wvec1[1] * x + wvec1[0]
  g2 = wvec2[1] * x + wvec2[0]

  if g1 > g2
    CLASS_1
  elsif g1 < g2
    CLASS_2
  end
end
# Builds the teacher (target) vector for one class: 1 for samples labelled
# with `positive_class`, 0 for every other label.
teacher_vector = lambda do |positive_class|
  labels.map { |label| label == positive_class ? 1 : 0 }
end

# Class 1: teacher vector, learned weights, and the resulting discriminant.
tvec1 = teacher_vector.call(CLASS_1)
wvec1 = get_wvec(xvecs, wvec, tvec1)
puts 'wvec1 = %s' % wvec1
puts 'g1(x) = %f x + %f' % [wvec1[1], wvec1[0]]
# Disabled debug output: g1 evaluated on every training sample.
# xvecs.row_count.times do |i|
#   xvec = xvecs.row(i)
#   label = labels[i]
#   puts 'g1(%s) = %s (class: %s)' % [xvec[1], wvec1.dot(xvec), label]
# end

# Class 2: teacher vector, learned weights, and the resulting discriminant.
tvec2 = teacher_vector.call(CLASS_2)
wvec2 = get_wvec(xvecs, wvec, tvec2)
puts 'wvec2 = %s' % wvec2
puts 'g2(x) = %f x + %f' % [wvec2[1], wvec2[0]]
# Disabled debug output: g2 evaluated on every training sample.
# xvecs.row_count.times do |i|
#   xvec = xvecs.row(i)
#   label = labels[i]
#   puts 'g2(%s) = %s (class: %s)' % [xvec[1], wvec2.dot(xvec), label]
# end

# Classify a handful of probe values with the two learned discriminants.
[-0.4, -0.39, -0.38, -0.37, -0.35, -0.3, -0.2].each do |x|
  klass = detect_class(wvec1, wvec2, x)
  puts "#{x}: class #{klass}"
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment