From 3c1552c90ffdccec9a63dbf20f948ec4703bb13a Mon Sep 17 00:00:00 2001 From: ifielker Date: Wed, 30 Sep 2020 17:12:18 -0400 Subject: [PATCH] Added support for Core ML models (#1051) * Added support for Core ML models --- package-lock.json | 16 +- src/index.d.ts | 29 +- .../machine-learning-api-client.ts | 25 +- src/machine-learning/machine-learning.ts | 53 +++- .../machine-learning-api-client.spec.ts | 168 ++++++++++-- .../machine-learning/machine-learning.spec.ts | 257 ++++++++++++++---- 6 files changed, 461 insertions(+), 87 deletions(-) diff --git a/package-lock.json b/package-lock.json index f17615e944..52890aef14 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,6 +1,6 @@ { "name": "firebase-admin", - "version": "9.1.1", + "version": "9.2.0", "lockfileVersion": 1, "requires": true, "dependencies": { @@ -275,9 +275,9 @@ } }, "@google-cloud/common": { - "version": "3.3.3", - "resolved": "https://registry.npmjs.org/@google-cloud/common/-/common-3.3.3.tgz", - "integrity": "sha512-2PwPDE47N4WiWQK/F35vE5aWVoCjKQ2NW8r8OFAg6QslkLMjX6WNcmUO8suYlSkavc58qOvzA4jG6eVkC90i8Q==", + "version": "3.4.0", + "resolved": "https://registry.npmjs.org/@google-cloud/common/-/common-3.4.0.tgz", + "integrity": "sha512-bVMQlK4aZEeopo2oJwDUJiBhPVjRRQHfFCCv9JowmKS3L//PBHNDJzC/LxJixGZEU3fh3YXkUwm67JZ5TBCCNQ==", "optional": true, "requires": { "@google-cloud/projectify": "^2.0.0", @@ -7653,14 +7653,14 @@ } }, "teeny-request": { - "version": "7.0.0", - "resolved": "https://registry.npmjs.org/teeny-request/-/teeny-request-7.0.0.tgz", - "integrity": "sha512-kWD3sdGmIix6w7c8ZdVKxWq+3YwVPGWz+Mq0wRZXayEKY/YHb63b8uphfBzcFDmyq8frD9+UTc3wLyOhltRbtg==", + "version": "7.0.1", + "resolved": "https://registry.npmjs.org/teeny-request/-/teeny-request-7.0.1.tgz", + "integrity": "sha512-sasJmQ37klOlplL4Ia/786M5YlOcoLGQyq2TE4WHSRupbAuDaQW0PfVxV4MtdBtRJ4ngzS+1qim8zP6Zp35qCw==", "optional": true, "requires": { "http-proxy-agent": "^4.0.0", "https-proxy-agent": "^5.0.0", - "node-fetch": "^2.2.0", + "node-fetch": "^2.6.1", "stream-events": "^1.0.5", "uuid": "^8.0.0" } diff --git a/src/index.d.ts b/src/index.d.ts index 997fcb8779..af8dfb7090 100644 --- a/src/index.d.ts +++ b/src/index.d.ts @@ -734,7 +734,12 @@ declare namespace admin.machineLearning { automlModel: string; }; } - type ModelOptions = ModelOptionsBase | GcsTfliteModelOptions | AutoMLTfliteModelOptions; + interface GcsCoremlModelOptions extends ModelOptionsBase { + coremlModel: { + gcsCoremlUri: string; + }; + } + type ModelOptions = ModelOptionsBase | GcsTfliteModelOptions | AutoMLTfliteModelOptions | GcsCoremlModelOptions; /** * A TensorFlow Lite Model output object @@ -753,6 +758,24 @@ declare namespace admin.machineLearning { * to Firebase. */ readonly automlModel?: string; + /** Indicates that the TFLite model was uploaded through the console. */ + readonly managedUpload?: boolean; + } + + /** + * A Core ML Model output object + */ + interface CoreMlModel { + /** The size of the model. */ + readonly sizeBytes: number; + + // One of the following 2 will be set: + + /** The URI from which the model was originally provided to Firebase. */ + readonly gcsCoremlUri?: string; + + /** Indicates that the Core ML model was uploaded through the console. */ + readonly managedUpload?: boolean; } /** @@ -820,8 +843,12 @@ declare namespace admin.machineLearning { */ toJSON(): {[key: string]: any}; + // At most one of the following will be specified. + /** Metadata about the model's TensorFlow Lite model file. */ readonly tfliteModel?: TFLiteModel; + /** Metadata about the model's Core ML model file. */ + readonly coremlModel?: CoreMlModel; } /** diff --git a/src/machine-learning/machine-learning-api-client.ts b/src/machine-learning/machine-learning-api-client.ts index e39c94db9a..f3549bfa1b 100644 --- a/src/machine-learning/machine-learning-api-client.ts +++ b/src/machine-learning/machine-learning-api-client.ts @@ -44,17 +44,27 @@ export interface ModelOptionsBase { displayName?: string; tags?: string[]; } + export interface GcsTfliteModelOptions extends ModelOptionsBase { tfliteModel: { gcsTfliteUri: string; }; } + export interface AutoMLTfliteModelOptions extends ModelOptionsBase { tfliteModel: { automlModel: string; }; } -export type ModelOptions = ModelOptionsBase | GcsTfliteModelOptions | AutoMLTfliteModelOptions; + +export interface GcsCoremlModelOptions extends ModelOptionsBase { + coremlModel: { + gcsCoremlUri: string; + }; +} + +export type ModelOptions = ModelOptionsBase | GcsTfliteModelOptions | AutoMLTfliteModelOptions | GcsCoremlModelOptions; + export type ModelUpdateOptions = ModelOptions & { state?: { published?: boolean }}; export function isGcsTfliteModelOptions(options: ModelOptions): options is GcsTfliteModelOptions { @@ -62,6 +72,12 @@ export function isGcsTfliteModelOptions(options: ModelOptions): options is GcsTf return typeof gcsUri !== 'undefined' } +export function isGcsCoremlModelOptions(options: ModelOptions): options is GcsCoremlModelOptions { + const gcsUri = (options as GcsCoremlModelOptions)?.coremlModel?.gcsCoremlUri; + return typeof gcsUri !== 'undefined' +} + + /** Interface representing listModels options. */ export interface ListModelsOptions { filter?: string; @@ -79,6 +95,13 @@ export interface ModelContent { readonly tfliteModel?: { readonly gcsTfliteUri?: string; readonly automlModel?: string; + readonly managedUpload?: boolean; + + readonly sizeBytes: number; + }; + readonly coremlModel?: { + readonly gcsCoremlUri?: string; + readonly managedUpload?: boolean; readonly sizeBytes: number; }; diff --git a/src/machine-learning/machine-learning.ts b/src/machine-learning/machine-learning.ts index 8430b66cfc..531b45526c 100644 --- a/src/machine-learning/machine-learning.ts +++ b/src/machine-learning/machine-learning.ts @@ -16,8 +16,8 @@ import { FirebaseApp } from '../firebase-app'; import { FirebaseServiceInterface, FirebaseServiceInternalsInterface } from '../firebase-service'; -import { MachineLearningApiClient, ModelResponse, ModelOptions, - ModelUpdateOptions, ListModelsOptions, isGcsTfliteModelOptions } from './machine-learning-api-client'; +import { MachineLearningApiClient, ModelResponse, ModelOptions, ModelUpdateOptions, + ListModelsOptions, isGcsTfliteModelOptions, isGcsCoremlModelOptions } from './machine-learning-api-client'; import { FirebaseError } from '../utils/error'; import * as validator from '../utils/validator'; @@ -207,6 +207,18 @@ export class MachineLearning implements FirebaseServiceInterface { `Error during signing upload url: ${err.message}`); }); } + if (isGcsCoremlModelOptions(modelOptions)) { + return this.signUrl(modelOptions.coremlModel.gcsCoremlUri) + .then((uri: string) => { + modelOptions.coremlModel.gcsCoremlUri = uri; + return modelOptions; + }) + .catch((err: Error) => { + throw new FirebaseMachineLearningError( + 'internal-error', + `Error during signing upload url: ${err.message}`); + }); + } return Promise.resolve(modelOptions); } @@ -285,6 +297,11 @@ export class Model { return deepCopy(this.model.tfliteModel); } + get coremlModel(): CoreMlModel | undefined { + // Make a copy so people can't directly modify the private this.model object. + return deepCopy(this.model.coremlModel); + } + /** * Locked indicates if there are active long running operations on the model. * Models may not be modified when they are locked. @@ -321,6 +338,10 @@ export class Model { jsonModel['tfliteModel'] = this.tfliteModel; } + if (this.coremlModel) { + jsonModel['coremlModel'] = this.coremlModel; + } + return jsonModel; } @@ -357,14 +378,24 @@ export class Model { const tmpModel = deepCopy(model); // If tflite Model is specified, it must have a source consisting of - // oneof {gcsTfliteUri, automlModel} + // oneof {gcsTfliteUri, automlModel, managedUpload} if (model.tfliteModel && !validator.isNonEmptyString(model.tfliteModel.gcsTfliteUri) && - !validator.isNonEmptyString(model.tfliteModel.automlModel)) { + !validator.isNonEmptyString(model.tfliteModel.automlModel) && + !model.tfliteModel.managedUpload) { // If we have some other source, ignore the whole tfliteModel. delete (tmpModel as any).tfliteModel; } + // If coreml Model is specified, it must have a source consisting of + // oneof {gcsCoremlUri, managedUpload} + if (model.coremlModel && + !validator.isNonEmptyString(model.coremlModel.gcsCoremlUri) && + !model.coremlModel.managedUpload) { + // If we have some other source, ignore the whole coremlModel. + delete (tmpModel as any).coremlModel; + } + // Remove '@type' field. We don't need it. if ((tmpModel as any)["@type"]) { delete (tmpModel as any)["@type"]; @@ -379,9 +410,21 @@ export class Model { export interface TFLiteModel { readonly sizeBytes: number; - // Oneof these two + // Oneof these three readonly gcsTfliteUri?: string; readonly automlModel?: string; + readonly managedUpload?: boolean; +} + +/** + * A Core ML Model output object + */ +export interface CoreMlModel { + readonly sizeBytes: number; + + // Oneof these two + readonly gcsCoremlUri?: string; + readonly managedUpload?: boolean; } function extractModelId(resourceName: string): string { diff --git a/test/unit/machine-learning/machine-learning-api-client.spec.ts b/test/unit/machine-learning/machine-learning-api-client.spec.ts index ae5fe16244..df8c258201 100644 --- a/test/unit/machine-learning/machine-learning-api-client.spec.ts +++ b/test/unit/machine-learning/machine-learning-api-client.spec.ts @@ -36,7 +36,7 @@ describe('MachineLearningApiClient', () => { const BASE_URL = 'https://firebaseml.googleapis.com/v1beta2'; const MODEL_ID = '1234567'; - const MODEL_RESPONSE = { + const TFLITE_GCS_MODEL_RESPONSE = { name: 'projects/test-project/models/1234567', createTime: '2020-02-07T23:45:23.288047Z', updateTime: '2020-02-08T23:45:23.288047Z', @@ -50,7 +50,7 @@ describe('MachineLearningApiClient', () => { sizeBytes: 16900988, }, }; - const MODEL_RESPONSE2 = { + const TFLITE_GCS_MODEL_RESPONSE2 = { name: 'projects/test-project/models/2345678', createTime: '2020-02-07T23:45:22.288047Z', updateTime: '2020-02-08T23:45:22.288047Z', @@ -64,7 +64,7 @@ describe('MachineLearningApiClient', () => { sizeBytes: 2220022, }, }; - const MODEL_RESPONSE_AUTOML = { + const TFLITE_AUTOML_RESPONSE = { name: 'projects/test-project/models/3456789', createTime: '2020-07-15T18:12:25.123987Z', updateTime: '2020-07-15T19:15:32.965435Z', @@ -78,6 +78,48 @@ describe('MachineLearningApiClient', () => { sizeBytes: 3330033, }, }; + const TFLITE_MANAGED_RESPONSE = { + name: 'projects/test-project/models/3456789', + createTime: '2020-07-15T18:12:25.123987Z', + updateTime: '2020-07-15T19:15:32.965435Z', + etag: 'etag345', + modelHash: 'modelHash345', + displayName: 'model_automl', + tags: ['tag_automl'], + state: { published: true }, + tfliteModel: { + managedUpload: true, + sizeBytes: 3330033, + }, + }; + const COREML_GCS_RESPONSE = { + name: 'projects/test-project/models/2345678', + createTime: '2020-02-07T23:45:22.288047Z', + updateTime: '2020-02-08T23:45:22.288047Z', + etag: 'etag234', + modelHash: 'modelHash234', + displayName: 'model_2', + tags: ['tag_2', 'tag_3'], + state: { published: true }, + coremlModel: { + gcsCoremlUri: 'gs://test-project-bucket/Firebase/ML/Models/model.mlmodel', + sizeBytes: 33300333, + }, + }; + const COREML_MANAGED_RESPONSE = { + name: 'projects/test-project/models/2345678', + createTime: '2020-02-07T23:45:22.288047Z', + updateTime: '2020-02-08T23:45:22.288047Z', + etag: 'etag234', + modelHash: 'modelHash234', + displayName: 'model_2', + tags: ['tag_2', 'tag_3'], + state: { published: true }, + coremlModel: { + managedUpload: true, + sizeBytes: 33300333, + }, + }; const PROJECT_ID = 'test-project'; const PROJECT_NUMBER = '1234567'; @@ -90,7 +132,7 @@ describe('MachineLearningApiClient', () => { }; const OPERATION_SUCCESS_RESPONSE = { done: true, - response: MODEL_RESPONSE, + response: TFLITE_GCS_MODEL_RESPONSE, }; const OPERATION_ERROR_RESPONSE = { done: true, @@ -107,7 +149,11 @@ describe('MachineLearningApiClient', () => { }; const OPERATION_AUTOML_RESPONSE = { done: true, - response: MODEL_RESPONSE_AUTOML, + response: TFLITE_AUTOML_RESPONSE, + }; + const OPERATION_COREML_GCS_RESPONSE = { + done: true, + response: COREML_GCS_RESPONSE, }; const LOCKED_MODEL_RESPONSE = { name: 'projects/test-project/models/1234567', @@ -173,7 +219,7 @@ describe('MachineLearningApiClient', () => { describe('createModel', () => { const NAME_ONLY_OPTIONS: ModelOptions = { displayName: 'name1' }; - const GCS_OPTIONS: ModelOptions = { + const TFLITE_GCS_OPTIONS: ModelOptions = { displayName: 'name2', tfliteModel: { gcsTfliteUri: 'gcsUri1', @@ -185,6 +231,12 @@ describe('MachineLearningApiClient', () => { automlModel: 'automlModel', }, }; + const COREML_GCS_OPTIONS: ModelOptions = { + displayName: 'coreml_gcs', + coremlModel: { + gcsCoremlUri: 'gcsUri2', + } + }; const invalidContent: any[] = [null, undefined, {}, { tags: [] }]; invalidContent.forEach((content) => { @@ -219,7 +271,7 @@ describe('MachineLearningApiClient', () => { .then((resp) => { expect(resp.done).to.be.true; expect(resp.name).to.be.undefined; - expect(resp.response).to.deep.equal(MODEL_RESPONSE); + expect(resp.response).to.deep.equal(TFLITE_GCS_MODEL_RESPONSE); }); }); @@ -228,11 +280,11 @@ describe('MachineLearningApiClient', () => { .stub(HttpClient.prototype, 'send') .resolves(utils.responseFrom(OPERATION_SUCCESS_RESPONSE)); stubs.push(stub); - return apiClient.createModel(GCS_OPTIONS) + return apiClient.createModel(TFLITE_GCS_OPTIONS) .then((resp) => { expect(resp.done).to.be.true; expect(resp.name).to.be.undefined; - expect(resp.response).to.deep.equal(MODEL_RESPONSE); + expect(resp.response).to.deep.equal(TFLITE_GCS_MODEL_RESPONSE); }); }); @@ -245,7 +297,20 @@ describe('MachineLearningApiClient', () => { .then((resp) => { expect(resp.done).to.be.true; expect(resp.name).to.be.undefined; - expect(resp.response).to.deep.equal(MODEL_RESPONSE_AUTOML); + expect(resp.response).to.deep.equal(TFLITE_AUTOML_RESPONSE); + }); + }); + + it('should accept Coreml GCS option', () => { + const stub = sinon + .stub(HttpClient.prototype, 'send') + .resolves(utils.responseFrom(OPERATION_COREML_GCS_RESPONSE)); + stubs.push(stub); + return apiClient.createModel(COREML_GCS_OPTIONS) + .then((resp) => { + expect(resp.done).to.be.true; + expect(resp.name).to.be.undefined; + expect(resp.response).to.deep.equal(COREML_GCS_RESPONSE); }); }); @@ -296,7 +361,7 @@ describe('MachineLearningApiClient', () => { describe('updateModel', () => { const NAME_ONLY_OPTIONS: ModelOptions = { displayName: 'name1' }; - const GCS_OPTIONS: ModelOptions = { + const TFLITE_GCS_OPTIONS: ModelOptions = { displayName: 'name2', tfliteModel: { gcsTfliteUri: 'gcsUri1', @@ -308,14 +373,22 @@ describe('MachineLearningApiClient', () => { automlModel: 'automlModel', }, }; + const COREML_GCS_OPTIONS: ModelOptions = { + displayName: 'coreml_gcs', + coremlModel: { + gcsCoremlUri: 'gcsUri2', + } + }; const NAME_ONLY_MASK_LIST = ['displayName']; - const GCS_MASK_LIST = ['displayName', 'tfliteModel.gcsTfliteUri']; + const TFLITE_GCS_MASK_LIST = ['displayName', 'tfliteModel.gcsTfliteUri']; const AUTOML_MASK_LIST = ['displayName', 'tfliteModel.automlModel']; + const COREML_GCS_MASK_LIST = ['displayName', 'coremlModel.gcsCoremlUri']; const NAME_ONLY_UPDATE_MASK_STRING = "updateMask=displayName"; - const GCS_UPDATE_MASK_STRING = "updateMask=displayName,tfliteModel.gcsTfliteUri"; + const TFLITE_GCS_UPDATE_MASK_STRING = "updateMask=displayName,tfliteModel.gcsTfliteUri"; const AUTOML_UPDATE_MASK_STRING = "updateMask=displayName,tfliteModel.automlModel"; + const COREML_UPDATE_MASK_STRING = "updateMask=displayName,coremlModel.gcsCoremlUri"; const invalidOptions: any[] = [null, undefined]; invalidOptions.forEach((option) => { @@ -356,7 +429,7 @@ describe('MachineLearningApiClient', () => { .then((resp) => { expect(resp.done).to.be.true; expect(resp.name).to.be.undefined; - expect(resp.response).to.deep.equal(MODEL_RESPONSE); + expect(resp.response).to.deep.equal(TFLITE_GCS_MODEL_RESPONSE); expect(stub).to.have.been.calledOnce.and.calledWith({ method: 'PATCH', headers: EXPECTED_HEADERS, @@ -366,21 +439,21 @@ describe('MachineLearningApiClient', () => { }); }); - it('should resolve with the updated GCS resource on success', () => { + it('should resolve with the updated Tflite GCS resource on success', () => { const stub = sinon .stub(HttpClient.prototype, 'send') .resolves(utils.responseFrom(OPERATION_SUCCESS_RESPONSE)); stubs.push(stub); - return apiClient.updateModel(MODEL_ID, GCS_OPTIONS, GCS_MASK_LIST) + return apiClient.updateModel(MODEL_ID, TFLITE_GCS_OPTIONS, TFLITE_GCS_MASK_LIST) .then((resp) => { expect(resp.done).to.be.true; expect(resp.name).to.be.undefined; - expect(resp.response).to.deep.equal(MODEL_RESPONSE); + expect(resp.response).to.deep.equal(TFLITE_GCS_MODEL_RESPONSE); expect(stub).to.have.been.calledOnce.and.calledWith({ method: 'PATCH', headers: EXPECTED_HEADERS, - url: `${BASE_URL}/projects/test-project/models/${MODEL_ID}?${GCS_UPDATE_MASK_STRING}`, - data: GCS_OPTIONS, + url: `${BASE_URL}/projects/test-project/models/${MODEL_ID}?${TFLITE_GCS_UPDATE_MASK_STRING}`, + data: TFLITE_GCS_OPTIONS, }); }); }); @@ -394,7 +467,7 @@ describe('MachineLearningApiClient', () => { .then((resp) => { expect(resp.done).to.be.true; expect(resp.name).to.be.undefined; - expect(resp.response).to.deep.equal(MODEL_RESPONSE); + expect(resp.response).to.deep.equal(TFLITE_GCS_MODEL_RESPONSE); expect(stub).to.have.been.calledOnce.and.calledWith({ method: 'PATCH', headers: EXPECTED_HEADERS, @@ -404,6 +477,25 @@ describe('MachineLearningApiClient', () => { }); }); + it('should resolve with the updated Coreml GCS resource on success', () => { + const stub = sinon + .stub(HttpClient.prototype, 'send') + .resolves(utils.responseFrom(OPERATION_COREML_GCS_RESPONSE)); + stubs.push(stub); + return apiClient.updateModel(MODEL_ID, COREML_GCS_OPTIONS, COREML_GCS_MASK_LIST) + .then((resp) => { + expect(resp.done).to.be.true; + expect(resp.name).to.be.undefined; + expect(resp.response).to.deep.equal(COREML_GCS_RESPONSE); + expect(stub).to.have.been.calledOnce.and.calledWith({ + method: 'PATCH', + headers: EXPECTED_HEADERS, + url: `${BASE_URL}/projects/test-project/models/${MODEL_ID}?${COREML_UPDATE_MASK_STRING}`, + data: COREML_GCS_OPTIONS, + }); + }); + }); + it('should resolve with error when the operation fails', () => { const stub = sinon .stub(HttpClient.prototype, 'send') @@ -486,6 +578,32 @@ describe('MachineLearningApiClient', () => { }); }); + it('should resolve with Model for tflite managed model response', () => { + const stub = sinon + .stub(HttpClient.prototype, 'send') + .resolves(utils.responseFrom(TFLITE_MANAGED_RESPONSE)); + stubs.push(stub); + return apiClient.getModel(MODEL_ID) + .then((resp) => { + expect(resp.name).to.equal('projects/test-project/models/3456789'); + expect(resp.tfliteModel?.managedUpload).to.be.true; + expect(resp.tfliteModel?.sizeBytes).to.be.equal(3330033); + }) + }); + + it('should resolve with Model for coreml managed model response', () => { + const stub = sinon + .stub(HttpClient.prototype, 'send') + .resolves(utils.responseFrom(COREML_MANAGED_RESPONSE)); + stubs.push(stub); + return apiClient.getModel(MODEL_ID) + .then((resp) => { + expect(resp.name).to.equal('projects/test-project/models/2345678'); + expect(resp.coremlModel?.managedUpload).to.be.true; + expect(resp.coremlModel?.sizeBytes).to.be.equal(33300333); + }) + }); + it('should reject when a full platform error response is received', () => { const stub = sinon .stub(HttpClient.prototype, 'send') @@ -591,7 +709,7 @@ describe('MachineLearningApiClient', () => { it('handles a done operation with result', () => { return apiClient.handleOperation(OPERATION_SUCCESS_RESPONSE) .then((resp) => { - expect(resp).deep.equals(MODEL_RESPONSE); + expect(resp).deep.equals(TFLITE_GCS_MODEL_RESPONSE); }); }); @@ -628,7 +746,7 @@ describe('MachineLearningApiClient', () => { baseWaitMillis: 2, maxWaitMillis: 5 }) .then((resp) => { - expect(resp).to.deep.equal(MODEL_RESPONSE); + expect(resp).to.deep.equal(TFLITE_GCS_MODEL_RESPONSE); expect(stub).to.have.been.calledTwice.and.calledWith({ method: 'GET', url: `${BASE_URL}/projects/${PROJECT_NUMBER}/operations/${OPERATION_ID}`, @@ -677,7 +795,7 @@ describe('MachineLearningApiClient', () => { describe('listModels', () => { const LIST_RESPONSE = { - models: [MODEL_RESPONSE, MODEL_RESPONSE2], + models: [TFLITE_GCS_MODEL_RESPONSE, TFLITE_GCS_MODEL_RESPONSE2], nextPageToken: 'next', }; @@ -754,8 +872,8 @@ describe('MachineLearningApiClient', () => { .then((resp) => { expect(resp.models).not.to.be.empty; expect(resp.models!.length).to.equal(2); - expect(resp.models![0]).to.deep.equal(MODEL_RESPONSE); - expect(resp.models![1]).to.deep.equal(MODEL_RESPONSE2); + expect(resp.models![0]).to.deep.equal(TFLITE_GCS_MODEL_RESPONSE); + expect(resp.models![1]).to.deep.equal(TFLITE_GCS_MODEL_RESPONSE2); expect(stub).to.have.been.calledOnce.and.calledWith({ method: 'GET', url: `${BASE_URL}/projects/test-project/models`, diff --git a/test/unit/machine-learning/machine-learning.spec.ts b/test/unit/machine-learning/machine-learning.spec.ts index 8d32619dff..a83fa6d7fb 100644 --- a/test/unit/machine-learning/machine-learning.spec.ts +++ b/test/unit/machine-learning/machine-learning.spec.ts @@ -39,7 +39,7 @@ describe('MachineLearning', () => { const EXPECTED_ERROR = new FirebaseMachineLearningError('internal-error', 'message'); const CREATE_TIME_UTC = 'Fri, 07 Feb 2020 23:45:23 GMT'; const UPDATE_TIME_UTC = 'Sat, 08 Feb 2020 23:45:23 GMT'; - const MODEL_RESPONSE: { + const TFLITE_GCS_MODEL_RESPONSE: { name: string; createTime: string; updateTime: string; @@ -74,7 +74,7 @@ describe('MachineLearning', () => { }; - const MODEL_RESPONSE2: { + const TFLITE_GCS_MODEL_RESPONSE2: { name: string; createTime: string; updateTime: string; @@ -108,7 +108,22 @@ describe('MachineLearning', () => { }, }; - const MODEL_RESPONSE3: any = { + const TFLITE_AUTOML_RESPONSE: any = { + name: 'projects/test-project/models/3456789', + createTime: '2020-02-07T23:45:23.288047Z', + updateTime: '2020-02-08T23:45:23.288047Z', + etag: 'etag345', + modelHash: 'modelHash345', + displayName: 'model_3', + tags: ['tag_3', 'tag_4'], + state: { published: true }, + tfliteModel: { + automlModel: 'projects/12456/locations/us-central1/model/ICN123456', + sizeBytes: 22200222, + }, + }; + + const TFLITE_MANAGED_RESPONSE: any = { name: 'projects/test-project/models/3456789', createTime: '2020-02-07T23:45:23.288047Z', updateTime: '2020-02-08T23:45:23.288047Z', @@ -123,6 +138,67 @@ describe('MachineLearning', () => { }, }; + const TFLITE_UNKNOWN_SOURCE_RESPONSE: any = { + name: 'projects/test-project/models/3456789', + createTime: '2020-02-07T23:45:23.288047Z', + updateTime: '2020-02-08T23:45:23.288047Z', + etag: 'etag345', + modelHash: 'modelHash345', + displayName: 'model_3', + tags: ['tag_3', 'tag_4'], + state: { published: true }, + tfliteModel: { + foo: true, + sizeBytes: 22200222, + }, + }; + + const COREML_GCS_RESPONSE: any = { + name: 'projects/test-project/models/3456789', + createTime: '2020-02-07T23:45:23.288047Z', + updateTime: '2020-02-08T23:45:23.288047Z', + etag: 'etag345', + modelHash: 'modelHash345', + displayName: 'model_3', + tags: ['tag_3', 'tag_4'], + state: { published: true }, + coremlModel: { + gcsCoremlUri: 'gs://test-project-bucket/Firebase/ML/Models/model6.mlmodel', + sizeBytes: 33300333, + }, + }; + + const COREML_MANAGED_RESPONSE: any = { + name: 'projects/test-project/models/3456789', + createTime: '2020-02-07T23:45:23.288047Z', + updateTime: '2020-02-08T23:45:23.288047Z', + etag: 'etag345', + modelHash: 'modelHash345', + displayName: 'model_3', + tags: ['tag_3', 'tag_4'], + state: { published: true }, + coremlModel: { + managedUpload: true, + sizeBytes: 33300333, + }, + }; + + const COREML_UNKNOWN_SOURCE_RESPONSE: any = { + name: 'projects/test-project/models/3456789', + createTime: '2020-02-07T23:45:23.288047Z', + updateTime: '2020-02-08T23:45:23.288047Z', + etag: 'etag345', + modelHash: 'modelHash345', + displayName: 'model_3', + tags: ['tag_3', 'tag_4'], + state: { published: true }, + coremlModel: { + foo: true, + sizeBytes: 33300333, + }, + } + + const STATUS_ERROR_RESPONSE: { code: number; message: string; @@ -131,7 +207,7 @@ describe('MachineLearning', () => { message: 'Invalid Argument message', }; - const OPERATION_RESPONSE: { + const TFLITE_GCS_OPERATION_RESPONSE: { name?: string; metadata?: any; done: boolean; @@ -158,7 +234,7 @@ describe('MachineLearning', () => { }; } = { done: true, - response: MODEL_RESPONSE, + response: TFLITE_GCS_MODEL_RESPONSE, }; const OPERATION_RESPONSE_ERROR: { @@ -246,8 +322,8 @@ describe('MachineLearning', () => { mockClient = new MachineLearningApiClient(mockApp); mockCredentialApp = mocks.mockCredentialApp(); machineLearning = new MachineLearning(mockApp); - model1 = new Model(MODEL_RESPONSE, mockClient); - model2 = new Model(MODEL_RESPONSE2, mockClient); + model1 = new Model(TFLITE_GCS_MODEL_RESPONSE, mockClient); + model2 = new Model(TFLITE_GCS_MODEL_RESPONSE2, mockClient); }); after(() => { @@ -309,8 +385,8 @@ describe('MachineLearning', () => { }); describe('Model', () => { - it('should successfully construct a model', () => { - const model = new Model(MODEL_RESPONSE, mockClient); + it('should successfully construct a gcs tflite model', () => { + const model = new Model(TFLITE_GCS_MODEL_RESPONSE, mockClient); expect(model.modelId).to.equal(MODEL_ID); expect(model.displayName).to.equal('model_1'); expect(model.tags).to.deep.equal(['tag_1', 'tag_2']); @@ -325,10 +401,11 @@ describe('MachineLearning', () => { expect(tflite.gcsTfliteUri).to.be.equal( 'gs://test-project-bucket/Firebase/ML/Models/model1.tflite'); expect(tflite.sizeBytes).to.be.equal(16900988); + expect(model.coremlModel).to.be.undefined; }); - it('should accept unknown fields gracefully', () => { - const model = new Model(MODEL_RESPONSE3, mockClient); + it('should successfully construct an automl tflite model', () => { + const model = new Model(TFLITE_AUTOML_RESPONSE, mockClient); expect(model.modelId).to.equal('3456789'); expect(model.displayName).to.equal('model_3'); expect(model.tags).to.deep.equal(['tag_3', 'tag_4']); @@ -338,11 +415,97 @@ describe('MachineLearning', () => { expect(model.published).to.be.true; expect(model.etag).to.equal('etag345'); expect(model.modelHash).to.equal('modelHash345'); + const tflite = model.tfliteModel!; + expect(tflite.automlModel).to.be.equal( + 'projects/12456/locations/us-central1/model/ICN123456'); + expect(tflite.sizeBytes).to.be.equal(22200222); + expect(model.coremlModel).to.be.undefined; + }); + + it('should successfully construct a managedUpload tflite model', () => { + const model = new Model(TFLITE_MANAGED_RESPONSE, mockClient); + expect(model.modelId).to.equal('3456789'); + expect(model.displayName).to.equal('model_3'); + expect(model.tags).to.deep.equal(['tag_3', 'tag_4']); + expect(model.createTime).to.equal(CREATE_TIME_UTC); + expect(model.updateTime).to.equal(UPDATE_TIME_UTC); + expect(model.validationError).to.be.undefined; + expect(model.published).to.be.true; + expect(model.etag).to.equal('etag345'); + expect(model.modelHash).to.equal('modelHash345'); + const tflite = model.tfliteModel!; + expect(tflite.managedUpload).to.be.true; + expect(tflite.sizeBytes).to.be.equal(22200222); + expect(model.coremlModel).to.be.undefined; + }); + + it('should accept unknown tflite fields gracefully', () => { + const model = new Model(TFLITE_UNKNOWN_SOURCE_RESPONSE, mockClient); + expect(model.modelId).to.equal('3456789'); + expect(model.displayName).to.equal('model_3'); + expect(model.tags).to.deep.equal(['tag_3', 'tag_4']); + expect(model.createTime).to.equal(CREATE_TIME_UTC); + expect(model.updateTime).to.equal(UPDATE_TIME_UTC); + expect(model.validationError).to.be.undefined; + expect(model.published).to.be.true; + expect(model.etag).to.equal('etag345'); + expect(model.modelHash).to.equal('modelHash345'); + expect(model.tfliteModel).to.be.undefined; + expect(model.coremlModel).to.be.undefined; + }); + + it('should successfully construct a gcs coreml model', () => { + const model = new Model(COREML_GCS_RESPONSE, mockClient); + expect(model.modelId).to.equal('3456789'); + expect(model.displayName).to.equal('model_3'); + expect(model.tags).to.deep.equal(['tag_3', 'tag_4']); + expect(model.createTime).to.equal(CREATE_TIME_UTC); + expect(model.updateTime).to.equal(UPDATE_TIME_UTC); + expect(model.validationError).to.be.undefined; + expect(model.published).to.be.true; + expect(model.etag).to.equal('etag345'); + expect(model.modelHash).to.equal('modelHash345'); + const coreml = model.coremlModel!; + expect(coreml.gcsCoremlUri).to.be.equal( + 'gs://test-project-bucket/Firebase/ML/Models/model6.mlmodel'); + expect(coreml.sizeBytes).to.be.equal(33300333); + expect(model.tfliteModel).to.be.undefined; + }); + + it('should successfully construct a managedUpload coreml model', () => { + const model = new Model(COREML_MANAGED_RESPONSE, mockClient); + expect(model.modelId).to.equal('3456789'); + expect(model.displayName).to.equal('model_3'); + expect(model.tags).to.deep.equal(['tag_3', 'tag_4']); + expect(model.createTime).to.equal(CREATE_TIME_UTC); + expect(model.updateTime).to.equal(UPDATE_TIME_UTC); + expect(model.validationError).to.be.undefined; + expect(model.published).to.be.true; + expect(model.etag).to.equal('etag345'); + expect(model.modelHash).to.equal('modelHash345'); + const coreml = model.coremlModel!; + expect(coreml.managedUpload).to.be.true; + expect(coreml.sizeBytes).to.be.equal(33300333); expect(model.tfliteModel).to.be.undefined; }); + it('should accept unknown coreml fields gracefully', () => { + const model = new Model(COREML_UNKNOWN_SOURCE_RESPONSE, mockClient); + expect(model.modelId).to.equal('3456789'); + expect(model.displayName).to.equal('model_3'); + expect(model.tags).to.deep.equal(['tag_3', 'tag_4']); + expect(model.createTime).to.equal(CREATE_TIME_UTC); + expect(model.updateTime).to.equal(UPDATE_TIME_UTC); + expect(model.validationError).to.be.undefined; + expect(model.published).to.be.true; + expect(model.etag).to.equal('etag345'); + expect(model.modelHash).to.equal('modelHash345'); + expect(model.tfliteModel).to.be.undefined; + expect(model.coremlModel).to.be.undefined; + }); + it('should successfully serialize a model to JSON', () => { - const model = new Model(MODEL_RESPONSE, mockClient); + const model = new Model(TFLITE_GCS_MODEL_RESPONSE, mockClient); const expectedModel = { modelId: MODEL_ID, displayName: 'model_1', @@ -368,7 +531,7 @@ describe('MachineLearning', () => { }); it('should return locked as false when no active operations are present', () => { - const model = new Model(MODEL_RESPONSE, mockClient); + const model = new Model(TFLITE_GCS_MODEL_RESPONSE, mockClient); expect(model.locked).to.be.false; }); @@ -378,7 +541,7 @@ describe('MachineLearning', () => { const stub = sinon .stub(MachineLearningApiClient.prototype, 'handleOperation') - .resolves(MODEL_RESPONSE2); + .resolves(TFLITE_GCS_MODEL_RESPONSE2); stubs.push(stub); model.waitForUnlocked() @@ -412,7 +575,7 @@ describe('MachineLearning', () => { }); it('should reject when API response does not contain a name', () => { - const response = deepCopy(MODEL_RESPONSE); + const response = deepCopy(TFLITE_GCS_MODEL_RESPONSE); response.name = ''; const stub = sinon .stub(MachineLearningApiClient.prototype, 'getModel') @@ -424,7 +587,7 @@ describe('MachineLearning', () => { }); it('should reject when API response does not contain a createTime', () => { - const response = deepCopy(MODEL_RESPONSE); + const response = deepCopy(TFLITE_GCS_MODEL_RESPONSE); response.createTime = ''; const stub = sinon .stub(MachineLearningApiClient.prototype, 'getModel') @@ -435,8 +598,8 @@ describe('MachineLearning', () => { 'message', `Invalid Model response: ${JSON.stringify(response)}`); }); - it('should reject when API response does not contain a updateTime', () => { - const response = deepCopy(MODEL_RESPONSE); + it('should reject when API response does not contain an updateTime', () => { + const response = deepCopy(TFLITE_GCS_MODEL_RESPONSE); response.updateTime = ''; const stub = sinon .stub(MachineLearningApiClient.prototype, 'getModel') @@ -448,7 +611,7 @@ describe('MachineLearning', () => { }); it('should reject when API response does not contain a displayName', () => { - const response = deepCopy(MODEL_RESPONSE); + const response = deepCopy(TFLITE_GCS_MODEL_RESPONSE); response.displayName = ''; const stub = sinon .stub(MachineLearningApiClient.prototype, 'getModel') @@ -460,7 +623,7 @@ describe('MachineLearning', () => { }); it('should reject when API response does not contain an etag', () => { - const response = deepCopy(MODEL_RESPONSE); + const response = deepCopy(TFLITE_GCS_MODEL_RESPONSE); response.etag = ''; const stub = sinon .stub(MachineLearningApiClient.prototype, 'getModel') @@ -474,7 +637,7 @@ describe('MachineLearning', () => { it('should resolve with Model on success', () => { const stub = sinon .stub(MachineLearningApiClient.prototype, 'getModel') - .resolves(MODEL_RESPONSE); + .resolves(TFLITE_GCS_MODEL_RESPONSE); stubs.push(stub); return machineLearning.getModel(MODEL_ID) @@ -488,8 +651,8 @@ describe('MachineLearning', () => { const LIST_MODELS_RESPONSE = { models: [ - MODEL_RESPONSE, - MODEL_RESPONSE2, + TFLITE_GCS_MODEL_RESPONSE, + TFLITE_GCS_MODEL_RESPONSE2, ], nextPageToken: 'next', }; @@ -582,7 +745,7 @@ describe('MachineLearning', () => { }); it('should reject when API response does not contain a name', () => { - const op = deepCopy(OPERATION_RESPONSE); + const op = deepCopy(TFLITE_GCS_OPERATION_RESPONSE); op.response!.name = ''; const stub = sinon .stub(MachineLearningApiClient.prototype, 'createModel') @@ -594,7 +757,7 @@ describe('MachineLearning', () => { }); it('should reject when API response does not contain a createTime', () => { - const op = deepCopy(OPERATION_RESPONSE); + const op = deepCopy(TFLITE_GCS_OPERATION_RESPONSE); op.response!.createTime = ''; const stub = sinon .stub(MachineLearningApiClient.prototype, 'createModel') @@ -606,7 +769,7 @@ describe('MachineLearning', () => { }); it('should reject when API response does not contain a updateTime', () => { - const op = deepCopy(OPERATION_RESPONSE); + const op = deepCopy(TFLITE_GCS_OPERATION_RESPONSE); op.response!.updateTime = ''; const stub = sinon .stub(MachineLearningApiClient.prototype, 'createModel') @@ -618,7 +781,7 @@ describe('MachineLearning', () => { }); it('should reject when API response does not contain a displayName', () => { - const op = deepCopy(OPERATION_RESPONSE); + const op = deepCopy(TFLITE_GCS_OPERATION_RESPONSE); op.response!.displayName = ''; const stub = sinon .stub(MachineLearningApiClient.prototype, 'createModel') @@ -630,7 +793,7 @@ describe('MachineLearning', () => { }); it('should reject when API response does not contain an etag', () => { - const op = deepCopy(OPERATION_RESPONSE); + const op = deepCopy(TFLITE_GCS_OPERATION_RESPONSE); op.response!.etag = ''; const stub = sinon .stub(MachineLearningApiClient.prototype, 'createModel') @@ -644,7 +807,7 @@ describe('MachineLearning', () => { it('should resolve with Model on success', () => { const stub = sinon .stub(MachineLearningApiClient.prototype, 'createModel') - .resolves(OPERATION_RESPONSE); + .resolves(TFLITE_GCS_OPERATION_RESPONSE); stubs.push(stub); return machineLearning.createModel(MODEL_OPTIONS_WITH_GCS) @@ -699,7 +862,7 @@ describe('MachineLearning', () => { }); it('should reject when API response does not contain a name', () => { - const op = deepCopy(OPERATION_RESPONSE); + const op = deepCopy(TFLITE_GCS_OPERATION_RESPONSE); op.response!.name = ''; const stub = sinon .stub(MachineLearningApiClient.prototype, 'updateModel') @@ -711,7 +874,7 @@ describe('MachineLearning', () => { }); it('should reject when API response does not contain a createTime', () => { - const op = deepCopy(OPERATION_RESPONSE); + const op = deepCopy(TFLITE_GCS_OPERATION_RESPONSE); op.response!.createTime = ''; const stub = sinon .stub(MachineLearningApiClient.prototype, 'updateModel') @@ -723,7 +886,7 @@ describe('MachineLearning', () => { }); it('should reject when API response does not contain a updateTime', () => { - const op = deepCopy(OPERATION_RESPONSE); + const op = deepCopy(TFLITE_GCS_OPERATION_RESPONSE); op.response!.updateTime = ''; const stub = sinon .stub(MachineLearningApiClient.prototype, 'updateModel') @@ -735,7 +898,7 @@ describe('MachineLearning', () => { }); it('should reject when API response does not contain a displayName', () => { - const op = deepCopy(OPERATION_RESPONSE); + const op = deepCopy(TFLITE_GCS_OPERATION_RESPONSE); op.response!.displayName = ''; const stub = sinon .stub(MachineLearningApiClient.prototype, 'updateModel') @@ -747,7 +910,7 @@ describe('MachineLearning', () => { }); it('should reject when API response does not contain an etag', () => { - const op = deepCopy(OPERATION_RESPONSE); + const op = deepCopy(TFLITE_GCS_OPERATION_RESPONSE); op.response!.etag = ''; const stub = sinon .stub(MachineLearningApiClient.prototype, 'updateModel') @@ -761,7 +924,7 @@ describe('MachineLearning', () => { it('should resolve with Model on success', () => { const stub = sinon .stub(MachineLearningApiClient.prototype, 'updateModel') - .resolves(OPERATION_RESPONSE); + .resolves(TFLITE_GCS_OPERATION_RESPONSE); stubs.push(stub); return machineLearning.updateModel(MODEL_ID, MODEL_OPTIONS_WITH_GCS) @@ -803,7 +966,7 @@ describe('MachineLearning', () => { }); it('should reject when API response does not contain a name', () => { - const op = deepCopy(OPERATION_RESPONSE); + const op = deepCopy(TFLITE_GCS_OPERATION_RESPONSE); op.response!.name = ''; const stub = sinon .stub(MachineLearningApiClient.prototype, 'updateModel') @@ -815,7 +978,7 @@ describe('MachineLearning', () => { }); it('should reject when API response does not contain a createTime', () => { - const op = deepCopy(OPERATION_RESPONSE); + const op = deepCopy(TFLITE_GCS_OPERATION_RESPONSE); op.response!.createTime = ''; const stub = sinon .stub(MachineLearningApiClient.prototype, 'updateModel') @@ -827,7 +990,7 @@ describe('MachineLearning', () => { }); it('should reject when API response does not contain a updateTime', () => { - const op = deepCopy(OPERATION_RESPONSE); + const op = deepCopy(TFLITE_GCS_OPERATION_RESPONSE); op.response!.updateTime = ''; const stub = sinon .stub(MachineLearningApiClient.prototype, 'updateModel') @@ -839,7 +1002,7 @@ describe('MachineLearning', () => { }); it('should reject when API response does not contain a displayName', () => { - const op = deepCopy(OPERATION_RESPONSE); + const op = deepCopy(TFLITE_GCS_OPERATION_RESPONSE); op.response!.displayName = ''; const stub = sinon .stub(MachineLearningApiClient.prototype, 'updateModel') @@ -851,7 +1014,7 @@ describe('MachineLearning', () => { }); it('should reject when API response does not contain an etag', () => { - const op = deepCopy(OPERATION_RESPONSE); + const op = deepCopy(TFLITE_GCS_OPERATION_RESPONSE); op.response!.etag = ''; const stub = sinon .stub(MachineLearningApiClient.prototype, 'updateModel') @@ -865,7 +1028,7 @@ describe('MachineLearning', () => { it('should resolve with Model on success', () => { const stub = sinon .stub(MachineLearningApiClient.prototype, 'updateModel') - .resolves(OPERATION_RESPONSE); + .resolves(TFLITE_GCS_OPERATION_RESPONSE); stubs.push(stub); return machineLearning.publishModel(MODEL_ID) @@ -907,7 +1070,7 @@ describe('MachineLearning', () => { }); it('should reject when API response does not contain a name', () => { - const op = deepCopy(OPERATION_RESPONSE); + const op = deepCopy(TFLITE_GCS_OPERATION_RESPONSE); op.response!.name = ''; const stub = sinon .stub(MachineLearningApiClient.prototype, 'updateModel') @@ -919,7 +1082,7 @@ describe('MachineLearning', () => { }); it('should reject when API response does not contain a createTime', () => { - const op = deepCopy(OPERATION_RESPONSE); + const op = deepCopy(TFLITE_GCS_OPERATION_RESPONSE); op.response!.createTime = ''; const stub = sinon .stub(MachineLearningApiClient.prototype, 'updateModel') @@ -931,7 +1094,7 @@ describe('MachineLearning', () => { }); it('should reject when API response does not contain a updateTime', () => { - const op = deepCopy(OPERATION_RESPONSE); + const op = deepCopy(TFLITE_GCS_OPERATION_RESPONSE); op.response!.updateTime = ''; const stub = sinon .stub(MachineLearningApiClient.prototype, 'updateModel') @@ -943,7 +1106,7 @@ describe('MachineLearning', () => { }); it('should reject when API response does not contain a displayName', () => { - const op = deepCopy(OPERATION_RESPONSE); + const op = deepCopy(TFLITE_GCS_OPERATION_RESPONSE); op.response!.displayName = ''; const stub = sinon .stub(MachineLearningApiClient.prototype, 'updateModel') @@ -955,7 +1118,7 @@ describe('MachineLearning', () => { }); it('should reject when API response does not contain an etag', () => { - const op = deepCopy(OPERATION_RESPONSE); + const op = deepCopy(TFLITE_GCS_OPERATION_RESPONSE); op.response!.etag = ''; const stub = sinon .stub(MachineLearningApiClient.prototype, 'updateModel') @@ -969,7 +1132,7 @@ describe('MachineLearning', () => { it('should resolve with Model on success', () => { const stub = sinon .stub(MachineLearningApiClient.prototype, 'updateModel') - .resolves(OPERATION_RESPONSE); + .resolves(TFLITE_GCS_OPERATION_RESPONSE); stubs.push(stub); return machineLearning.unpublishModel(MODEL_ID)