Skip to content

Instantly share code, notes, and snippets.

@janpfeifer
Last active February 19, 2025 06:41
Show Gist options
  • Save janpfeifer/9d7f00f4c1129efe3417ffb3edd754e6 to your computer and use it in GitHub Desktop.
Save janpfeifer/9d7f00f4c1129efe3417ffb3edd754e6 to your computer and use it in GitHub Desktop.
Safe cosine similarity
// CosineSimilarity ....
// lhs -> left-hand side
// rhs -> right-hand side
func CosineSimilarity(lhs, rhs *Node) *Node {
g := lhs.Graph()
dtype := lhs.DType()
axis := -1 // Axis over which to calculate the cosine.
// Mask for rows that are fully zero, for which cosine similary is not normally defined.
lhsMask := ReduceAndKeep(IsZero(lhs), ReduceLogicalAnd, axis),
rhsMask := ReduceAndKeep(IsZero(rhs), ReduceLogicalAnd, axis)
// Recover original shape, by broadcasting the mask where we just reduced.
lhsMask = BroadcastToShape(lhs, lhs.Shape())
rhsMask = BroadcastToShape(rhs, rhs.Shape())
// Replace rows with all zeroes (lhsMask/rhsMask) with 1.
// Any positive numerical safe number would work, since the final computation for
// those rows won't be used, as long as they are not NaNs.
one := ScalarOne(g, dtype)
lhs = Where(lhsMask, lhs, one)
rhs = Where(lhsMask, rhs, one)
// ... calculate similarity as usual
// Arbitrarily set the similarity of the zero-rows (lhsMask or rhsMask) to zero.
zero := ScalarZero(g, dtype)
similarity = Where(LogicalOr(lhsMask, rhsMask), zero, similarity)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment