Skip to content

Instantly share code, notes, and snippets.

@hayabusabusa
Last active July 2, 2019 09:42
Show Gist options
  • Save hayabusabusa/5d09705bd29ab22df76c81c8bb2512e2 to your computer and use it in GitHub Desktop.
Save hayabusabusa/5d09705bd29ab22df76c81c8bb2512e2 to your computer and use it in GitHub Desktop.
//
// FirebaseProvider.swift
// MlKitSample
//
import Foundation
import Firebase
import FirebaseMLCommon
/// Runs a Firebase ML Kit remote TensorFlow Lite model ("test_gender_model") on images.
///
/// Use the single instance exposed as `shared`. Creating it registers the remote
/// model with `ModelManager` and creates a `ModelInterpreter` for it; registration
/// failure is treated as a programmer error and traps.
final class FirebaseMLKitProvider {
    static let shared: FirebaseMLKitProvider = FirebaseMLKitProvider()

    // MARK: - Model configuration

    /// Name under which the remote model is registered and later looked up.
    private static let remoteModelName = "test_gender_model"
    /// Side length, in pixels, of the square input image (input shape is [1, side, side, 3]).
    private static let inputSize = 224
    /// Number of color channels (R, G, B) written per pixel.
    private static let channelCount = 3
    /// Number of scores expected in the first output row (male, female).
    private static let outputCount = 2

    // MARK: - Private property

    private let interpreter: ModelInterpreter

    private init() {
        // Read the TensorFlow Lite custom model.
        // The initial download may use cellular data; updates may not.
        let initialConditions = ModelDownloadConditions(
            allowsCellularAccess: true,
            allowsBackgroundDownloading: true
        )
        let updateConditions = ModelDownloadConditions(
            allowsCellularAccess: false,
            allowsBackgroundDownloading: true
        )
        let remoteModel = RemoteModel(
            name: FirebaseMLKitProvider.remoteModelName,
            allowsModelUpdates: true,
            initialConditions: initialConditions,
            updateConditions: updateConditions
        )
        guard ModelManager.modelManager().register(remoteModel) else {
            fatalError("MLKit custom tflite register failed.")
        }
        // Create the interpreter for the registered remote model (no local fallback).
        let option = ModelOptions(remoteModelName: FirebaseMLKitProvider.remoteModelName, localModelName: nil)
        interpreter = ModelInterpreter.modelInterpreter(options: option)
    }

    /// Converts `image` into the model's input tensor bytes.
    ///
    /// The image is scaled (aspect ratio not preserved) to `inputSize` x `inputSize`
    /// pixels and rendered into an 8-bit RGB bitmap. Each pixel's R, G and B bytes
    /// are normalized to `[0.0, 1.0]` and appended as native `Float32` values in
    /// row-major order (y outer, x inner, then channel), matching the
    /// `[1, inputSize, inputSize, 3]` input format set in `startInference`.
    ///
    /// - Parameter image: Source image of any size.
    /// - Returns: The packed tensor bytes, or empty `Data` if the image could not be
    ///   converted or rendered.
    private func createInputData(image: UIImage) -> Data {
        let side = FirebaseMLKitProvider.inputSize

        // UIImage sizes are in points and CGImage sizes in pixels; the helper draws at
        // scale 1.0 so the resulting CGImage should be `side` x `side` pixels.
        guard let cgImage = image.resizeToPixelSize(width: CGFloat(side), height: CGFloat(side)).cgImage else {
            print("Failed to convert CGImage.")
            return Data()
        }

        // Bitmap with a known layout: 8 bits per component, 4 bytes per pixel, and the
        // first byte of each pixel unused (noneSkipFirst). Passing bytesPerRow 0 with
        // nil data lets CoreGraphics choose the row stride, so it is read back below.
        // The context is always side x side, so the loops below can never read past the
        // buffer even if the resize above returned an image of a different size.
        guard let context = CGContext(
            data: nil,
            width: side,
            height: side,
            bitsPerComponent: 8,
            bytesPerRow: 0,
            space: CGColorSpaceCreateDeviceRGB(),
            bitmapInfo: CGImageAlphaInfo.noneSkipFirst.rawValue
        ) else {
            return Data()
        }
        // Draw the image into the whole context, scaling it to fit if necessary.
        context.draw(cgImage, in: CGRect(x: 0, y: 0, width: side, height: side))
        guard let pixels = context.data else { return Data() }

        let bytesPerRow = context.bytesPerRow
        let bytesPerPixel = 4
        var values = [Float32]()
        values.reserveCapacity(side * side * FirebaseMLKitProvider.channelCount)

        // Traverse row by row (y outer, x inner) so values are laid out as
        // [height][width][channel], as the input shape requires.
        for y in 0 ..< side {
            let rowStart = y * bytesPerRow
            for x in 0 ..< side {
                let offset = rowStart + x * bytesPerPixel
                // Byte 0 is the unused alpha byte; bytes 1...3 are R, G, B.
                for channel in 1 ... FirebaseMLKitProvider.channelCount {
                    let byte = pixels.load(fromByteOffset: offset + channel, as: UInt8.self)
                    // Normalize channel values to [0.0, 1.0]. This requirement varies by
                    // model: some expect [-1.0, 1.0], fixed-point values, or raw bytes.
                    values.append(Float32(byte) / 255.0)
                }
            }
        }
        return values.withUnsafeBufferPointer { Data(buffer: $0) }
    }

    /// Runs the model on `image` and reports its two output scores.
    ///
    /// - Parameters:
    ///   - image: Image to classify.
    ///   - completion: Called with a `GenderModel` built from output row 0
    ///     (index 0 as `male`, index 1 as `female`) when inference succeeds. It is not
    ///     called when configuring the I/O formats, building the input, running the
    ///     interpreter, or reading the output fails; those failures are logged instead.
    func startInference(image: UIImage, completion: @escaping (_ model: GenderModel) -> Void ) {
        let side = FirebaseMLKitProvider.inputSize

        // Setting input and output formats.
        let ioOptions = ModelInputOutputOptions()
        do {
            let inputShape = [1, side, side, FirebaseMLKitProvider.channelCount].map { NSNumber(value: $0) }
            let outputShape = [1, FirebaseMLKitProvider.outputCount].map { NSNumber(value: $0) }
            try ioOptions.setInputFormat(index: 0, type: .float32, dimensions: inputShape)
            try ioOptions.setOutputFormat(index: 0, type: .float32, dimensions: outputShape)
        } catch {
            print("Failed to set input or output format with error: \(error.localizedDescription)")
            return
        }

        // Set up input data; an empty buffer means the image could not be converted.
        let inputData = createInputData(image: image)
        guard !inputData.isEmpty else {
            print("Failed to add input: could not build input data from image")
            return
        }
        let inputs = ModelInputs()
        do {
            try inputs.addInput(inputData)
        } catch {
            print("Failed to add input: \(error)")
            return
        }

        // Run the interpreter.
        interpreter.run(inputs: inputs, options: ioOptions) { outputs, error in
            guard error == nil, let outputs = outputs else {
                print(error?.localizedDescription ?? "nil")
                return
            }
            // Validate the output shape before indexing so a malformed result cannot crash.
            guard let output = (try? outputs.output(index: 0)) as? [[NSNumber]],
                  let scores = output.first,
                  scores.count >= FirebaseMLKitProvider.outputCount else {
                print("Unexpected model output")
                return
            }
            print("Output: \(output)")
            completion(GenderModel(male: scores[0].doubleValue, female: scores[1].doubleValue))
        }
    }
}
import UIKit
extension UIImage {
    /// Returns a copy of the image redrawn at exactly `width` x `height` pixels.
    ///
    /// The image is drawn into a bitmap context whose scale is 1.0, so the returned
    /// image's point size equals its pixel size. The image is stretched to fill the
    /// target rectangle; the aspect ratio is not preserved.
    ///
    /// - Parameters:
    ///   - width: Target width in pixels.
    ///   - height: Target height in pixels.
    /// - Returns: The redrawn image, or `self` if the context produced no image.
    func resizeToPixelSize(width: CGFloat, height: CGFloat) -> UIImage {
        let size = CGSize(width: width, height: height)
        // Scale 1.0 makes one point equal one pixel; `false` keeps the context non-opaque.
        UIGraphicsBeginImageContextWithOptions(size, false, 1.0)
        // Always pop the context, whichever path returns.
        defer { UIGraphicsEndImageContext() }
        draw(in: CGRect(origin: .zero, size: size))
        return UIGraphicsGetImageFromCurrentImageContext() ?? self
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment