Created
June 28, 2024 19:45
-
-
Save peterbe/33e164684dcb15898a18d5bcb665e8d2 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Default tokenizer: splits text into word tokens.
 *
 * Every character that is not a letter, combining mark, digit, underscore or
 * whitespace is replaced by a space, then the text is split on runs of
 * whitespace. Empty tokens (from empty input or leading/trailing whitespace)
 * are dropped so they are never counted as words.
 *
 * @param text - Raw text to tokenize.
 * @returns The word tokens in order of appearance; `[]` when there are none.
 */
function defaultTokenizer(text: string): string[] {
  // Unicode property escapes keep letters of any script (including Cyrillic
  // `Ё`/`ё`, which an `А-Я` range misses). Parentheses and `+` are
  // deliberately NOT listed, so they are stripped like any other punctuation.
  const rgxPunctuation = /[^\p{L}\p{M}\p{N}_\s]/gu;
  const sanitized = text.replace(rgxPunctuation, " ");
  return sanitized.split(/\s+/).filter((token) => token.length > 0);
}
/** Function that turns a text into the list of tokens the classifier counts. */
type TokenizerFunction = (text: string) => string[];

/** Constructor options for {@link Naivebayes}. */
type Options = {
  /** Custom tokenization function; the module's default tokenizer is used when omitted. */
  tokenizer?: TokenizerFunction;
};

/**
 * Creates an empty lookup table without a prototype.
 *
 * Tokens and category names are arbitrary strings. On a plain `{}` the keys
 * `constructor`, `toString`, `__proto__`, ... resolve to inherited
 * `Object.prototype` members and corrupt the counts.
 */
function createTable<T>(): { [key: string]: T } {
  return Object.create(null) as { [key: string]: T };
}

/**
 * Naive-Bayes Classifier
 *
 * This is a naive-bayes classifier that uses Laplace Smoothing.
 *
 * Takes an (optional) options object containing:
 *   - `tokenizer`  => custom tokenization function
 *
 */
export class Naivebayes {
  /** Tokenizer applied to every learned and categorized text. */
  tokenizer: TokenizerFunction;
  /** Set of all tokens seen during learning (token => true). */
  vocabulary: { [key: string]: boolean }; // XXX change to Set
  /** Number of distinct tokens in `vocabulary`. */
  vocabularySize: number; // XXX Shouldn't be needed
  /** Number of documents we have learned from. */
  totalDocuments: number;
  /** For each category, how many documents were mapped to it. */
  docCount: { [key: string]: number };
  /** For each category, how many words in total were mapped to it. */
  wordCount: { [key: string]: number };
  /** For each category, how often each token was mapped to it. */
  wordFrequencyCount: { [key: string]: { [key: string]: number } };
  /** Set of known category names (category => true). */
  categories: { [key: string]: boolean };

  /**
   * @param options - Optional settings; `tokenizer` overrides the default tokenizer.
   */
  constructor({ tokenizer = defaultTokenizer }: Options = {}) {
    this.tokenizer = tokenizer;
    this.vocabulary = createTable<boolean>();
    this.vocabularySize = 0;
    this.totalDocuments = 0;
    this.docCount = createTable<number>();
    this.wordCount = createTable<number>();
    this.wordFrequencyCount = createTable<{ [key: string]: number }>();
    this.categories = createTable<boolean>();
  }

  /**
   * Registers a category with zeroed counters if it is not known yet.
   *
   * @param categoryName - Name of the category to register.
   * @returns This classifier, to allow chaining.
   */
  initializeCategory(categoryName: string) {
    if (!this.categories[categoryName]) {
      this.docCount[categoryName] = 0;
      this.wordCount[categoryName] = 0;
      this.wordFrequencyCount[categoryName] = createTable<number>();
      this.categories[categoryName] = true;
    }
    return this;
  }

  /**
   * Trains the classifier: records that `text` belongs to `category`.
   *
   * @param text - Document to learn from.
   * @param category - Category label the document is mapped to.
   */
  learn(text: string, category: string) {
    this.initializeCategory(category);
    this.docCount[category] = (this.docCount[category] ?? 0) + 1;
    this.totalDocuments++;

    const frequencyTable = this.frequencyTable(this.tokenizer(text));

    // initializeCategory() guarantees the table exists; the fallback only
    // covers a table that was removed from the outside.
    let categoryFrequencies = this.wordFrequencyCount[category];
    if (categoryFrequencies === undefined) {
      categoryFrequencies = createTable<number>();
      this.wordFrequencyCount[category] = categoryFrequencies;
    }

    for (const [token, frequencyInText] of Object.entries(frequencyTable)) {
      // Grow the vocabulary the first time a token is seen in any category.
      if (!this.vocabulary[token]) {
        this.vocabulary[token] = true;
        this.vocabularySize++;
      }
      categoryFrequencies[token] =
        (categoryFrequencies[token] ?? 0) + frequencyInText;
      this.wordCount[category] =
        (this.wordCount[category] ?? 0) + frequencyInText;
    }
  }

  /**
   * Counts how often each token occurs in `tokens`.
   *
   * @param tokens - Tokens of a single document.
   * @returns A prototype-less table mapping token => number of occurrences.
   */
  frequencyTable(tokens: string[]): { [key: string]: number } {
    const frequencyTable = createTable<number>();
    for (const token of tokens) {
      frequencyTable[token] = (frequencyTable[token] ?? 0) + 1;
    }
    return frequencyTable;
  }

  /**
   * Determines the most probable category for `text`.
   *
   * Probabilities are summed in log space so that multiplying many small
   * token probabilities does not underflow.
   *
   * @param text - Document to classify.
   * @returns The category with the highest log probability, or `""` when no
   *   category scores above `-Infinity` (e.g. nothing has been learned yet).
   */
  categorize(text: string): string {
    const tokenFrequencies = Object.entries(
      this.frequencyTable(this.tokenizer(text)),
    );

    let maxLogProbability = -Infinity;
    let chosenCategory = "";

    for (const category of Object.keys(this.categories)) {
      // Prior: share of learned documents that belong to this category.
      const categoryProbability =
        (this.docCount[category] ?? 0) / this.totalDocuments;
      let logProbability = Math.log(categoryProbability);

      for (const [token, frequencyInText] of tokenFrequencies) {
        logProbability +=
          frequencyInText * Math.log(this.tokenProbability(token, category));
      }

      if (logProbability > maxLogProbability) {
        maxLogProbability = logProbability;
        chosenCategory = category;
      }
    }

    return chosenCategory;
  }

  /**
   * Laplace-smoothed probability of `token` given `category`:
   * `(count(token, category) + 1) / (wordCount(category) + vocabularySize)`.
   *
   * @param token - Token to look up.
   * @param category - Category to look the token up in.
   * @returns The smoothed probability; unseen tokens and categories count as zero occurrences.
   */
  tokenProbability(token: string, category: string) {
    const tokenCount = this.wordFrequencyCount[category]?.[token] ?? 0;
    const wordCount = this.wordCount[category] ?? 0;
    return (tokenCount + 1) / (wordCount + this.vocabularySize);
  }

  /**
   * Serialization is not implemented yet.
   *
   * NOTE: `JSON.stringify()` calls `toJSON()`, so stringifying a classifier
   * currently throws.
   *
   * @throws Error always.
   */
  toJSON() {
    throw new Error("Method not implemented.");
  }

  /**
   * Deserialization is not implemented yet.
   *
   * @param jsonStr - Serialized classifier state (currently unused).
   * @throws Error always.
   */
  fromJSON(jsonStr: string) {
    throw new Error("Method not implemented.");
  }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import { readFileSync } from "fs"; | |
import util from "node:util"; | |
import chalk from "chalk"; | |
import { Naivebayes } from "./naive-bayes"; | |
// Command-line arguments after the runtime and script path:
//   1st => path of the JSON file with the training comments
//   2nd => path of the JSON file with the samples to classify
// Neither path is validated here; a missing argument is `undefined` at runtime.
const args = process.argv.slice(2);
const [trainedFilePath, sampleFilePath] = args as [string, string];
/** One sample to classify, as read from the samples JSON file. */
type Sample = {
  /** Numeric rating; `main` keeps only samples rated `>= 0.9`. */
  rating: number;
  /** Text of the comment that gets categorized. */
  comment: string;
};

/** One labelled training entry from the `comments` array of the training JSON file. */
type Training = {
  /** Text of the comment the classifier learns from. */
  comment: string;
  /** Label of the comment; passed to the classifier as the category name. */
  rating: string;
};
/**
 * Trains a naive-bayes classifier from the training file, then classifies the
 * highly rated samples and prints the ones categorized as "good".
 *
 * Reads `trainedFilePath` (JSON object with a `comments` array of
 * {@link Training} entries) and `sampleFilePath` (JSON array of
 * {@link Sample} entries). Prints the per-rating training counts and the
 * per-category classification counts.
 *
 * @throws Error when a file path argument is missing or a file does not
 *   contain the expected array; file-system and JSON parse errors propagate.
 */
async function main(): Promise<void> {
  if (!trainedFilePath || !sampleFilePath) {
    throw new Error(
      "Usage: <script> <trained-json-path> <samples-json-path>",
    );
  }

  const trainedData: unknown = JSON.parse(
    readFileSync(trainedFilePath, "utf8"),
  );
  const comments = (trainedData as { comments?: unknown } | null)?.comments;
  if (!Array.isArray(comments)) {
    throw new Error(`Expected a "comments" array in ${trainedFilePath}`);
  }
  const trained = comments as Training[];
  console.log("TRAINED", trained.length);

  const sampleData: unknown = JSON.parse(readFileSync(sampleFilePath, "utf8"));
  if (!Array.isArray(sampleData)) {
    throw new Error(`Expected an array of samples in ${sampleFilePath}`);
  }
  const samples = sampleData as Sample[];
  console.log("SAMPLES", samples.length);

  // Only samples with a high rating are classified.
  const goodSamples = samples.filter((s) => s.rating >= 0.9);
  console.log("GOOD SAMPLES", goodSamples.length);

  const bayes = new Naivebayes();

  // Learn every training comment under its rating and tally the ratings.
  const countTraining: Record<string, number> = {};
  for (const training of trained) {
    bayes.learn(training.comment, training.rating);
    countTraining[training.rating] = (countTraining[training.rating] ?? 0) + 1;
  }
  console.log("COUNT TRAINING:", countTraining);

  // Classify the good samples. A leftover early `return` used to stop the
  // function before this point, leaving the classification unreachable.
  const count: Record<string, number> = {};
  for (const sample of goodSamples) {
    const category = bayes.categorize(sample.comment);
    count[category] = (count[category] ?? 0) + 1;
    if (category === "good") {
      console.log(chalk.grey(util.inspect(sample.comment)));
      console.log(`CATEGORY: ${category}`);
      console.log("\n");
    }
  }
  console.log("COUNT:", count);
}
// Run the script. A failure is reported and turned into a non-zero exit code
// instead of being left as an unhandled promise rejection.
main().catch((error: unknown) => {
  console.error(error);
  process.exitCode = 1;
});
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment