Created
July 8, 2016 01:35
-
-
Save nikitakit/d3ec270aee9d930267cec3efa844d5aa to your computer and use it in GitHub Desktop.
Opening TensorFlow models in Scala
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Create an interactive TensorFlow session.
sess = tf.InteractiveSession()
# [graph creation here]

# Write the session's GraphDef to MODEL_DIR/model.pb as serialized protobuf bytes
# ("wb": the serialized form is binary, not text).
model_path = os.path.join(MODEL_DIR, "model.pb")
with open(model_path, "wb") as model_file:
    model_file.write(sess.graph_def.SerializeToString())
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package epic.tensorflow | |
import org.bytedeco.javacpp.{tensorflow, BytePointer} | |
import org.bytedeco.javacpp.tensorflow.GraphDef | |
/**
 * Loads a binary-serialized TensorFlow `GraphDef` from disk and adds it to the shared
 * `TensorflowSession.sess` via `Session.Extend`.
 *
 * All of the work happens in the constructor: creating an instance is what imports the graph.
 *
 * Created by kitaev on 1/22/16.
 *
 * @param filename path to a serialized `GraphDef` file (e.g. the bytes produced by
 *                 `graph_def.SerializeToString()` in Python)
 * @throws java.io.IOException if the file cannot be read
 * @throws IllegalArgumentException if the file contents cannot be parsed as a `GraphDef`
 * @throws IllegalStateException if the shared session rejects the parsed graph
 */
class TensorflowModel(filename: String) {
  /** The parsed graph. Its `version` field is overwritten with 0 before it is added to the session. */
  val graphdef: GraphDef = new GraphDef()

  /** Raw bytes of the model file. */
  val graphdef_data: Array[Byte] =
    java.nio.file.Files.readAllBytes(java.nio.file.Paths.get(filename))

  /** JavaCPP pointer over `graphdef_data`, handed to the native protobuf parser. */
  val graphdef_data_ptr: BytePointer = new BytePointer(java.nio.ByteBuffer.wrap(graphdef_data))

  /** Result of `ParseProtoUnlimited`; necessarily `true` once construction has completed. */
  val import_ok: Boolean = tensorflow.ParseProtoUnlimited(graphdef, graphdef_data_ptr)
  if (!import_ok) {
    // Bad input file, so report it as an argument problem (still a subclass of Exception).
    throw new IllegalArgumentException(s"Graphdef import failed from: ${filename}")
  }

  // TODO(nikita): figure out how versioning is supposed to work
  graphdef.set_version(0)

  /** Status returned by `Session.Extend`; necessarily ok once construction has completed. */
  val s = TensorflowSession.sess.Extend(graphdef)
  if (!s.ok()) {
    // Include the file name so the failure can be traced back to a specific model.
    throw new IllegalStateException(
      s"Failed to extend TensorflowSession with graph from ${filename}: ${s.ToString().getString()}")
  }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package epic.tensorflow | |
import org.bytedeco.javacpp.tensorflow | |
import org.bytedeco.javacpp.tensorflow._ | |
/**
 * Process-wide TensorFlow session shared by `TensorflowModel` instances.
 *
 * The session is created with default `SessionOptions` and initialized with an empty
 * `GraphDef`; model graphs are added to it afterwards with `Extend`.
 *
 * Because this is a Scala `object`, the body below runs once, on first access. If session
 * creation fails, that first access throws (wrapped in an `ExceptionInInitializerError`).
 *
 * Created by kitaev on 1/22/16.
 */
object TensorflowSession {
  /** The shared session. */
  val sess: tensorflow.Session = new tensorflow.Session(new SessionOptions())

  // Start from an empty graph; models are merged in later via Extend.
  private val graphdef = new tensorflow.GraphDef()

  // Create returns a Status. Check it rather than ignoring it, so a failed setup is reported
  // here instead of as an indirect error from a later Extend/Run call.
  private val createStatus = sess.Create(graphdef)
  if (!createStatus.ok()) {
    throw new IllegalStateException(
      s"Failed to create TensorflowSession: ${createStatus.ToString().getString()}")
  }
  // [project-specific code redacted]
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment