Created April 27, 2019 01:26
Multinomial naive Bayes in Julia, allowing generic numeric types (including rational numbers) for the conditional probabilities, so that exact probabilities can be computed.
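As a quick standalone illustration of why the generic element type matters (this snippet is not part of the gist itself): Julia's built-in Rational arithmetic stays exact where floating point rounds.

0.1 + 0.2 == 0.3        # false: Float64 rounds
1//10 + 2//10 == 3//10  # true: Rational arithmetic is exact
3//4 * (3//7)^3         # 81//1372, with no rounding anywhere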
struct NaiveBayes{T, V<:AbstractVector, M<:AbstractMatrix}
    probabilities::M  # 2×n conditional probabilities, one row per class
    priors::V         # class prior probabilities [P(class 1), P(class 2)]
end
train(::Type{NaiveBayes}, T::Type{R}, features, labels, α = 1) where R<:Real =
    train(NaiveBayes{T, Vector{T}, Matrix{T}}, features, labels, α)
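# The convenience method above picks dense storage for element type T, so
# train(NaiveBayes, Rational, ...) yields a
# NaiveBayes{Rational, Vector{Rational}, Matrix{Rational}} whose entries are
# exact, while train(NaiveBayes, Float32, ...) yields a
# NaiveBayes{Float32, Vector{Float32}, Matrix{Float32}}.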
# Generate exact (Rational, //) and floating-point (Real, /) methods
# from the same template.
for (typ, op) in ((Rational, :(//)), (Real, :(/)))
    @eval begin
        function train(::Type{NaiveBayes{T, S, R}},
                features, labels, α = 1) where
                {T<:$typ, S<:AbstractVector{T}, R<:AbstractMatrix{T}}
            m, n = size(features)
            pr = zeros(T, 2, n) # conditional probabilities
            # Calculate priors and per-class token totals
            counts = zeros(Int, 2)
            counts_category = zeros(Int, 2)
            for i in 1:m
                k = labels[i] ? 1 : 2
                counts[k] += 1
                counts_category[k] += sum(view(features, i, :))
            end
            priors = ($op)(counts, m)
            for j in 1:n
                counts = zeros(Int, 2)
                for i in 1:m
                    k = labels[i] ? 1 : 2
                    counts[k] += features[i, j]
                end
                for k in 1:2
                    # Additive (Laplace) smoothing with parameter α
                    pr[k, j] = ($op)(counts[k] + α, counts_category[k] + α * n)
                end
            end
            return NaiveBayes{T, S, R}(pr, priors)
        end
        function predict(NB::NaiveBayes, T::Type{R}, features,
                normalize::Bool=false) where R<:($typ)
            # Copy the priors: convert can return the original array when the
            # types already match, and we must not mutate NB.priors below.
            pr = Vector{T}(NB.priors)
            n = length(features)
            for k in 1:2, j in 1:n
                x = features[j]
                if x != 0
                    pr[k] *= NB.probabilities[k, j]^x
                end
            end
            if normalize
                pr = ($op)(pr, sum(pr))
            end
            return T(pr[1]), T(pr[2])
        end
    end
end
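# The estimates computed above are the usual smoothed maximum-likelihood ones:
#   priors[k] = count(class k) / m
#   pr[k, j]  = (count of feature j in class k + α) /
#               (total feature count in class k + α*n)
# and predict multiplies the prior by pr[k, j]^x for each observed count x.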
# Default prediction: Float32, unnormalized scores.
predict(NB::NaiveBayes, features) = predict(NB, Float32, features)
#Example from
#https://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html
featureset =
    #Chinese Beijing Shanghai Macao Japan Tokyo
    [2 1 0 0 0 0
     2 0 1 0 0 0
     1 0 0 1 0 0
     1 0 0 0 1 1]
labels = [true, true, true, false]
NB = train(NaiveBayes, Rational, featureset, labels)
@assert NB.priors == [3//4, 1//4]
@assert NB.probabilities[1, 1] == 3//7
@assert NB.probabilities[1, 5] == 1//14
@assert NB.probabilities[1, 6] == 1//14
@assert NB.probabilities[2, 1] == 2//9
@assert NB.probabilities[2, 5] == 2//9
@assert NB.probabilities[2, 6] == 2//9
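# These follow from the smoothing formula: class "yes" has 8 tokens in total,
# so P(Chinese|yes) = (5 + 1)/(8 + 6) = 3//7 and
# P(Tokyo|yes) = (0 + 1)/(8 + 6) = 1//14; class "no" has 3 tokens, so
# P(Chinese|no) = P(Japan|no) = P(Tokyo|no) = (1 + 1)/(3 + 6) = 2//9.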
pred_yes, pred_no = predict(NB, Rational, [3, 0, 0, 0, 1, 1], true) # normalize
#Unnormalized probabilities
upr_yes = 3//4 * (3//7)^3 * 1//14 * 1//14
upr_no = 1//4 * (2//9)^3 * 2//9 * 2//9
pr_yes = upr_yes // (upr_yes + upr_no)
pr_no = upr_no // (upr_yes + upr_no)
@assert pred_yes == pr_yes
@assert pred_no == pr_no
NBF = train(NaiveBayes, Float32, featureset, labels)
# Unnormalized Float32 scores from the exact model and the Float32 model
predict(NB, Float32, [3, 0, 0, 0, 1, 1])
predict(NBF, Float32, [3, 0, 0, 0, 1, 1])
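# The one-argument fallback defined above gives the same Float32 scores:
predict(NB, [3, 0, 0, 0, 1, 1])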