@CVxTz
Last active February 12, 2019 21:48
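from tensorflow.keras.layers import Dense, Dropout, Embedding, Flatten, Input, Multiply
from tensorflow.keras.models import Model


def get_graph_embedding_model(n_nodes):
    # Two integer inputs: the ids of the two nodes in a pair.
    in_1 = Input((1,))
    in_2 = Input((1,))
    # One Embedding layer shared by both inputs, so each node has a single 100-d vector.
    emb = Embedding(n_nodes, 100, name="node1")
    x1 = emb(in_1)
    x2 = emb(in_2)
    # (batch, 1, 100) -> (batch, 100), followed by light dropout for regularization.
    x1 = Flatten()(x1)
    x1 = Dropout(0.1)(x1)
    x2 = Flatten()(x2)
    x2 = Dropout(0.1)(x2)
    # Element-wise product of the two node vectors, then a single linear output unit.
    x = Multiply()([x1, x2])
    x = Dropout(0.1)(x)
    x = Dense(1, activation="linear", name="spl")(x)
    model = Model([in_1, in_2], x)
    # Scalar regression trained with mean absolute error.
    model.compile(loss="mae", optimizer="adam")
    model.summary()
    return model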
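A minimal NumPy sketch, under stated assumptions, of what this model computes at inference time and one plausible way to build its training pairs. The output layer is named "spl", which suggests a shortest-path-length target, but that is an assumption; the helpers `bfs_shortest_path_lengths`, `make_pairs`, `predict`, and `mae`, the random weights, and the toy path graph are all hypothetical and not part of the gist. At inference Dropout is inactive, so the Keras model reduces to `Dense(1)(emb[u] * emb[v])`.

```python
from collections import deque

import numpy as np


def bfs_shortest_path_lengths(adj, source):
    """Hop distance from `source` to every reachable node (unweighted graph)."""
    dist = {source: 0}
    queue = deque([source])
    while queue:
        u = queue.popleft()
        for v in adj[u]:
            if v not in dist:
                dist[v] = dist[u] + 1
                queue.append(v)
    return dist


def make_pairs(adj):
    """Hypothetical training data: (u, v, spl) for every reachable pair u != v."""
    pairs = []
    for u in range(len(adj)):
        for v, d in bfs_shortest_path_lengths(adj, u).items():
            if v != u:
                pairs.append((u, v, d))
    return np.array(pairs, dtype=np.int64)


def predict(emb, w, b, u, v):
    """Inference-time forward pass: Dropout is the identity, so the score is
    (emb[u] * emb[v]) @ w + b, i.e. Dense(1, linear) on the element-wise product."""
    return (emb[u] * emb[v]) @ w + b


def mae(y_true, y_pred):
    """Mean absolute error, the loss the Keras model is compiled with."""
    return float(np.mean(np.abs(y_true - y_pred)))


# Usage on a hypothetical toy path graph 0-1-2-3 (adjacency lists).
adj = [[1], [0, 2], [1, 3], [2]]
pairs = make_pairs(adj)
rng = np.random.default_rng(0)
emb = rng.normal(scale=0.1, size=(len(adj), 100))  # stands in for the "node1" weights
w = rng.normal(scale=0.1, size=100)                # stands in for the "spl" kernel
b = 0.0                                            # stands in for the "spl" bias
preds = predict(emb, w, b, pairs[:, 0], pairs[:, 1])
print(pairs.shape, mae(pairs[:, 2], preds))
```

Because both inputs go through the same Embedding layer and are combined with an element-wise product, the score is symmetric in (u, v), which suits an undirected target such as a hop distance.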