Last active
July 2, 2019 09:42
-
-
Save hayabusabusa/5d09705bd29ab22df76c81c8bb2512e2 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// | |
// FirebaseProvider.swift | |
// MlKitSample | |
// | |
import Firebase
import FirebaseMLCommon
import Foundation
import UIKit
/// Singleton wrapper around Firebase ML Kit's custom-model interpreter.
/// Registers the remote "test_gender_model" TensorFlow Lite model and runs
/// gender inference on a `UIImage`.
final class FirebaseMLKitProvider {
    static let shared: FirebaseMLKitProvider = FirebaseMLKitProvider()

    // MARK: - Private properties

    /// Name of the remote model registered in the Firebase console.
    private static let modelName = "test_gender_model"
    /// Model input is a square RGB image of this many pixels per side
    /// (see the [1, 224, 224, 3] input format below).
    private static let inputSide = 224

    private var interpreter: ModelInterpreter

    private init() {
        // Download conditions for the TensorFlow Lite custom model:
        // first download may use cellular; updates are Wi-Fi only.
        let initialConditions = ModelDownloadConditions(
            allowsCellularAccess: true,
            allowsBackgroundDownloading: true
        )
        let updateConditions = ModelDownloadConditions(
            allowsCellularAccess: false,
            allowsBackgroundDownloading: true
        )
        let remoteModel = RemoteModel(
            name: FirebaseMLKitProvider.modelName,
            allowsModelUpdates: true,
            initialConditions: initialConditions,
            updateConditions: updateConditions
        )
        // Registration failure is a programming/configuration error; crash early.
        guard ModelManager.modelManager().register(remoteModel) else {
            fatalError("MLKit custom tflite register failed.")
        }
        // Create the interpreter backed by the remote model only (no local fallback).
        let option = ModelOptions(remoteModelName: FirebaseMLKitProvider.modelName, localModelName: nil)
        interpreter = ModelInterpreter.modelInterpreter(options: option)
    }

    /// Converts `image` into the normalized Float32 RGB buffer the model
    /// expects: 224x224 pixels, channel values scaled to [0.0, 1.0], appended
    /// in row-major RGB order. Returns empty `Data` when the bitmap context
    /// cannot be created or yields no backing store.
    private func createInputData(image: UIImage) -> Data {
        let side = FirebaseMLKitProvider.inputSide
        // Resize and grab the CGImage. Note: UIImage sizes are in points,
        // CGImage in pixels — resizeToPixelSize renders at scale 1.0 so the
        // two coincide.
        guard let cgImage = image.resizeToPixelSize(width: CGFloat(side), height: CGFloat(side)).cgImage else {
            fatalError("Failed to convert CGImage.")
        }
        // Redraw into a bitmap with a known layout: noneSkipFirst over
        // DeviceRGB gives 4 bytes per pixel, byte 0 unused, then R, G, B.
        // https://developer.apple.com/documentation/coregraphics/cgcontext/1455939-init
        guard let context = CGContext(
            data: nil,
            width: cgImage.width,
            height: cgImage.height,
            bitsPerComponent: 8,
            bytesPerRow: 0, // 0 + nil data => CoreGraphics computes (and may pad) it
            space: CGColorSpaceCreateDeviceRGB(),
            bitmapInfo: CGImageAlphaInfo.noneSkipFirst.rawValue
        ) else {
            return Data()
        }
        // Copy the (already resized) image into the context.
        context.draw(cgImage, in: CGRect(x: 0, y: 0, width: cgImage.width, height: cgImage.height))
        guard let imageData = context.data else { return Data() }

        var inputData = Data()
        // 3 Float32 channels (4 bytes each) per pixel.
        inputData.reserveCapacity(side * side * 3 * 4)
        for row in 0 ..< side {
            for col in 0 ..< side {
                // FIX: the original computed `4 * (col * context.width + row)`,
                // which transposed the image. Pixels are row-major; also use
                // bytesPerRow instead of width * 4, because the auto-computed
                // row stride may include alignment padding.
                let offset = row * context.bytesPerRow + col * 4
                // Byte at `offset` is the unused skip/alpha channel.
                let red = imageData.load(fromByteOffset: offset + 1, as: UInt8.self)
                let green = imageData.load(fromByteOffset: offset + 2, as: UInt8.self)
                let blue = imageData.load(fromByteOffset: offset + 3, as: UInt8.self)
                // Normalize channel values from 0-255 to [0.0, 1.0]. This
                // requirement varies by model; some expect [-1.0, 1.0] or raw bytes.
                let normalizedRed = Float32(red) / 255.0
                let normalizedGreen = Float32(green) / 255.0
                let normalizedBlue = Float32(blue) / 255.0
                // Append the raw little-endian Float32 bytes in RGB order.
                withUnsafeBytes(of: normalizedRed) { inputData.append(contentsOf: $0) }
                withUnsafeBytes(of: normalizedGreen) { inputData.append(contentsOf: $0) }
                withUnsafeBytes(of: normalizedBlue) { inputData.append(contentsOf: $0) }
            }
        }
        return inputData
    }

    /// Runs the gender model on `image` and invokes `completion` with the
    /// male/female scores from the output tensor. `completion` is not called
    /// when setup or inference fails; errors are logged to the console.
    func startInference(image: UIImage, completion: @escaping (_ model: GenderModel) -> Void ) {
        // Describe the model's tensors: one [1, 224, 224, 3] Float32 input,
        // one [1, 2] Float32 output (male/female scores).
        let ioOptions = ModelInputOutputOptions()
        do {
            try ioOptions.setInputFormat(index: 0, type: .float32, dimensions: [1, 224, 224, 3])
            try ioOptions.setOutputFormat(index: 0, type: .float32, dimensions: [1, 2])
        } catch let error as NSError {
            print("Failed to set input or output format with error: \(error.localizedDescription)")
            // FIX: don't run the interpreter with half-configured formats.
            return
        }
        // Build the input tensor from the image.
        let inputs = ModelInputs()
        do {
            try inputs.addInput(createInputData(image: image))
        } catch let error {
            print("Failed to add input: \(error)")
            // FIX: don't run the interpreter without its input.
            return
        }
        // Run the interpreter; the callback delivers outputs or an error.
        interpreter.run(inputs: inputs, options: ioOptions) { outputs, error in
            guard error == nil, let outputs = outputs else {
                print(error?.localizedDescription ?? "nil")
                return
            }
            // Output tensor arrives as a nested NSNumber array.
            guard let output = try? outputs.output(index: 0) as? [[NSNumber]] else {
                return
            }
            completion(GenderModel(male: output[0][0].doubleValue, female: output[0][1].doubleValue))
        }
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import UIKit | |
extension UIImage {
    /// Returns a copy of this image redrawn at exactly `width` x `height`
    /// pixels. The context is created at scale 1.0, so its point size equals
    /// its pixel size (a UIImage's `size` is in points, not pixels).
    /// Falls back to `self` if no image can be produced.
    func resizeToPixelSize(width: CGFloat, height: CGFloat) -> UIImage {
        let size = CGSize(width: width, height: height)
        UIGraphicsBeginImageContextWithOptions(size, false, 1.0)
        // FIX: removed a dead `UIGraphicsGetCurrentContext()` call whose
        // result was discarded; `defer` guarantees the context is ended on
        // every exit path.
        defer { UIGraphicsEndImageContext() }
        draw(in: CGRect(x: 0, y: 0, width: size.width, height: size.height))
        return UIGraphicsGetImageFromCurrentImageContext() ?? self
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment