Skip to content

Instantly share code, notes, and snippets.

@lethalbrains
Created August 8, 2018 06:51
Show Gist options
  • Save lethalbrains/5e4f81dd506b7a4b17dabcbbf88706c0 to your computer and use it in GitHub Desktop.
Save lethalbrains/5e4f81dd506b7a4b17dabcbbf88706c0 to your computer and use it in GitHub Desktop.
def __find_best_split_for_column(self, col):
    """Find the threshold on column ``col`` with the highest information gain.

    Each distinct value ``v`` of the column is tried as a threshold: rows
    with ``value <= v`` form the left group and rows with ``value > v`` the
    right group. Candidates that leave either group empty (the column
    maximum, or a NaN value whose comparisons are all False) are skipped:
    they do not partition the data, and scoring an empty group can yield a
    NaN gain that would block every later ``score > best`` comparison.

    Args:
        col: Label of the column in ``self.data`` to evaluate.

    Returns:
        A ``(information_gain, split)`` tuple for the best threshold, or
        ``(None, None)`` when the column has a single distinct value or no
        threshold separates the rows into two non-empty groups.
    """
    data = self.data
    x = data[col]
    unique_values = x.unique()
    if len(unique_values) == 1:
        return None, None

    # Loop-invariant: the target column is the only part of each partition
    # that the impurity computation reads.
    target = data[self.target]

    information_gain = None
    split = None
    for val in unique_values:
        left_target = target[x <= val]
        right_target = target[x > val]

        # A one-sided "split" is not a split; skip it.
        if len(left_target) == 0 or len(right_target) == 0:
            continue

        left_impurity = self.__calculate_impurity_score(left_target)
        right_impurity = self.__calculate_impurity_score(right_target)
        score = self.__calculate_information_gain(
            left_count=len(left_target),
            left_impurity=left_impurity,
            right_count=len(right_target),
            right_impurity=right_impurity,
        )

        # Strict '>' keeps the first-seen threshold on ties.
        if information_gain is None or score > information_gain:
            information_gain = score
            split = val

    return information_gain, split
def __calculate_information_gain(self, left_count, left_impurity, right_count, right_impurity):
    """Return the impurity reduction achieved by a two-way split.

    The result is ``self.impurity_score`` minus the child impurities
    weighted by each child's share of ``len(self.data)``.

    Args:
        left_count: Number of rows in the left partition.
        left_impurity: Impurity score of the left partition.
        right_count: Number of rows in the right partition.
        right_impurity: Impurity score of the right partition.

    Returns:
        ``self.impurity_score`` minus the size-weighted sum of the two
        child impurities.
    """
    total_rows = len(self.data)
    left_share = left_count / total_rows
    right_share = right_count / total_rows
    weighted_child_impurity = left_share * left_impurity + right_share * right_impurity
    return self.impurity_score - weighted_child_impurity
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment