Skip to content

Instantly share code, notes, and snippets.

@dmitryhd
Created March 19, 2019 20:23
Show Gist options
  • Save dmitryhd/644e94e1b5872395f3e58413a7e6f38c to your computer and use it in GitHub Desktop.
Save dmitryhd/644e94e1b5872395f3e58413a7e6f38c to your computer and use it in GitHub Desktop.
class LabelEncoder:
    """Map hashable values to consecutive integer labels and back.

    Labels are handed out in first-seen order starting at 0: the first
    distinct value passed to ``encode`` gets 0, the next new one gets 1,
    and so on. Encoding a value that was seen before returns the label it
    was originally given. The encoder grows as new values arrive; nothing
    is ever removed.
    """

    def __init__(self):
        # Forward map: value -> label.
        self._values2index = {}
        # Label that will be assigned to the next previously unseen value.
        self._next_index = 0
        # Reverse map: label -> value.
        self._index2value = {}

    def encode(self, value) -> int:
        """Return the integer label for ``value``, assigning one if needed.

        ``value`` must be hashable because it is used as a dict key; an
        unhashable value raises ``TypeError``.
        """
        # Stored labels are always ints, so a ``None`` result from ``get``
        # can only mean "not seen yet" -- even when ``value`` itself is None.
        index = self._values2index.get(value)
        if index is None:
            index = self._next_index
            self._values2index[value] = index
            self._index2value[index] = value
            self._next_index = index + 1
        return index

    def decode(self, index: int):
        """Return the value that was assigned ``index``.

        Returns ``None`` when ``index`` has never been assigned. Note that
        this is indistinguishable from decoding the label of a value that
        is itself ``None``.
        """
        return self._index2value.get(index)

    def encode_batch(self, values: list) -> list:
        """Encode each item of ``values`` in order and return the labels."""
        return [self.encode(val) for val in values]

    def decode_batch(self, indices: list) -> list:
        """Decode each item of ``indices`` in order.

        Unknown indices produce ``None`` in the corresponding position.
        """
        return [self.decode(i) for i in indices]
# Usage demo: four distinct values, with 'a' repeated at the end to show
# that a value seen before gets its original label back.
x = list('abcd') + ['a']
l = LabelEncoder()
# First-seen order gives labels [0, 1, 2, 3, 0]; the result is discarded here.
l.encode_batch(x)
# Round trip back to ['a', 'b', 'c', 'd', 'a']; the result is discarded here.
l.decode_batch([*range(4), 0])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment