Skip to content

Instantly share code, notes, and snippets.

@pataiji
Last active February 16, 2017 03:34
Show Gist options
  • Save pataiji/cf2d7cea4aebac42a7c39af556cf3d4b to your computer and use it in GitHub Desktop.
Save pataiji/cf2d7cea4aebac42a7c39af556cf3d4b to your computer and use it in GitHub Desktop.
#!/usr/bin/env ruby
require 'matrix'
CLASS_1 = 1
CLASS_2 = 2

# Training data: every row is [feature value, class label].
data = Matrix[
  [ 1.0, CLASS_1],
  [ 0.5, CLASS_1],
  [-0.2, CLASS_2],
  [-0.4, CLASS_1],
  [-1.3, CLASS_2],
  [-2.0, CLASS_2]
]

# Feature matrix: one single-element row per training sample.
features = Matrix.column_vector(data.column(0).to_a)

# Class label of each sample, in the same row order as `data`.
labels = data.column(1)

# Initial weight vector.
wvec = Vector[0.2, 0.3]

# Augmented input rows: a constant 1 is prepended to each feature row (xvec[0] = 1).
xvecs = Matrix.hstack(Matrix.column_vector(Array.new(features.row_count, 1)), features)
# Trains the weight coefficients: one pass over every weight component.
#
# For each weight index `key` in turn, every row of `xvecs` is visited; the
# error `wvec.dot(xvec) - tvec[i]` times `xvec[key]` is added to a running
# sum, and the component at `key` is then decreased by `low * sum`.
# Half of the accumulated squared error is printed to stdout.
#
# NOTE(review): the weight is rewritten after every sample using the running
# sum, and the squared error is accumulated once per weight pass, so the
# printed value is not the cost of one single weight vector. This schedule is
# kept exactly as it was; confirm whether a batch update was intended.
#
# @param xvecs [Matrix] input vectors, one per row
# @param wvec [Vector] current weights; the caller's object is not modified
# @param tvec [#[]] target value for each row of `xvecs`
# @return [Vector] the updated weights
def train(xvecs, wvec, tvec)
  low = 0.2 # learning coefficient
  j = 0     # accumulated squared error, used only for the printed value
  wvec.size.times do |key|
    sum = 0
    xvecs.row_count.times do |i|
      xvec = xvecs.row(i)
      err = wvec.dot(xvec) - tvec[i]
      sum += err * xvec[key]
      j += err**2
      # Update only the component at position `key`. Selecting it by index
      # (rather than by comparing values) keeps other components that happen
      # to hold the same value from being overwritten as well.
      updated = wvec.to_a
      updated[key] -= low * sum
      wvec = Vector.elements(updated, false)
    end
  end
  puts j / 2
  wvec
end
# Obtains the weight coefficients by applying `train` repeatedly.
#
# Each call to `train` receives the weights returned by the previous call.
#
# @param xvecs [Matrix] input vectors, passed through to `train`
# @param wvec [Vector] initial weights
# @param tvec [#[]] target values, passed through to `train`
# @param iterations [Integer] number of `train` calls to make (default 100)
# @return [Vector] the weights returned by the last `train` call, or `wvec`
#   unchanged when `iterations` is 0
def get_wvec(xvecs, wvec, tvec, iterations: 100)
  iterations.times { wvec = train(xvecs, wvec, tvec) }
  wvec
end
# Picks a class for `x` by comparing two linear functions of `x`.
#
# Each weight argument is read as [intercept, slope]: the value compared is
# `w[1] * x + w[0]`.
#
# @param wvec1 [#[]] weights whose function stands for CLASS_1
# @param wvec2 [#[]] weights whose function stands for CLASS_2
# @param x [Numeric] input value
# @return [Integer, nil] CLASS_1 when the first value is larger, CLASS_2 when
#   the second is larger, nil when neither is larger
def detect_class(wvec1, wvec2, x)
  score1, score2 = [wvec1, wvec2].map { |w| w[1] * x + w[0] }

  if score1 > score2
    CLASS_1
  elsif score1 < score2
    CLASS_2
  end
end
# Class 1
# Teacher vector for class 1: 1 for CLASS_1 samples, 0 for CLASS_2 samples.
tvec1 = labels.map do |label|
  case label
  when CLASS_1 then 1
  when CLASS_2 then 0
  end
end
wvec1 = get_wvec(xvecs, wvec, tvec1)
puts format('wvec1 = %s', wvec1)
puts format('g1(x) = %f x + %f', wvec1[1], wvec1[0])
# xvecs.row_count.times do |i|
# xvec = xvecs.row(i)
# label = labels[i]
# puts 'g1(%s) = %s (クラス:%s)' % [xvec[1], wvec1.dot(xvec), label]
# end
# Class 2
# Teacher vector for class 2: 0 for CLASS_1 samples, 1 for CLASS_2 samples.
tvec2 = labels.map do |label|
  case label
  when CLASS_1 then 0
  when CLASS_2 then 1
  end
end
wvec2 = get_wvec(xvecs, wvec, tvec2)
puts format('wvec2 = %s', wvec2)
puts format('g2(x) = %f x + %f', wvec2[1], wvec2[0])
# xvecs.row_count.times do |i|
# xvec = xvecs.row(i)
# label = labels[i]
# puts 'g2(%s) = %s (クラス:%s)' % [xvec[1], wvec2.dot(xvec), label]
# end
# Print the class chosen by detect_class for a few sample inputs.
[-0.4, -0.39, -0.38, -0.37, -0.35, -0.3, -0.2].each do |x|
  puts "#{x}: class #{detect_class(wvec1, wvec2, x)}"
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment