diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..272dbb2 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,31 @@ +name: Run Unit Tests + +on: + push: + branches: + - '**' + pull_request: + branches: + - '**' + +concurrency: + group: ${{ github.ref }} + cancel-in-progress: true + +jobs: + publish-npm: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + + - uses: actions/setup-node@v2 + with: + node-version: '18.x' + registry-url: 'https://registry.npmjs.org' + + - name: Install pnpm + run: npm install -g pnpm + - run: pnpm install + - run: pnpm run test + - run: pnpm run build + diff --git a/.github/workflows/e2e-test.yml b/.github/workflows/e2e-test.yml new file mode 100644 index 0000000..fc5b7f7 --- /dev/null +++ b/.github/workflows/e2e-test.yml @@ -0,0 +1,33 @@ +name: Run E2E Tests + +on: + pull_request: + branches: + - 'main' + +concurrency: + group: ${{ github.ref }}-e2e-tests + cancel-in-progress: true + +env: + CI: true + DEEPINFRA_API_KEY: ${{ secrets.DEEPINFRA_API_KEY }} + +jobs: + e2e-test: + if: false + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + + - uses: actions/setup-node@v2 + with: + node-version: '18.x' + registry-url: 'https://registry.npmjs.org' + + - name: Install pnpm + run: npm install -g pnpm + - run: pnpm install + - run: pnpm run test:e2e + - run: pnpm run build + diff --git a/.gitignore b/.gitignore index dc563b7..fab3dfb 100644 --- a/.gitignore +++ b/.gitignore @@ -144,3 +144,4 @@ docs dist **misc.ts docs +.husky diff --git a/.husky/pre-commit b/.husky/pre-commit deleted file mode 100644 index 605953f..0000000 --- a/.husky/pre-commit +++ /dev/null @@ -1 +0,0 @@ -. "$(dirname -- "$0")/_/husky.sh" diff --git a/.husky/prepare-commit-msg b/.husky/prepare-commit-msg deleted file mode 100644 index 339d9fd..0000000 --- a/.husky/prepare-commit-msg +++ /dev/null @@ -1,6 +0,0 @@ -npm run lint -npm run prettier -npm run build -npm run test - -exec < /dev/tty && git cz --hook || true diff --git a/jest.e2e.config.js b/jest.e2e.config.js new file mode 100644 index 0000000..ab80b08 --- /dev/null +++ b/jest.e2e.config.js @@ -0,0 +1,22 @@ +/** @type {import('ts-jest').JstConfigWithTsJest} */ +module.exports = { + testEnvironment: 'node', + extensionsToTreatAsEsm: [".ts"], + testTimeout: 10000, + coveragePathIgnorePatterns: [ + "/node_modules/", + "/examples/", + "/test/" + ], + moduleNameMapper: { + "^@/(.*)$": "/src/$1" + }, + testMatch: [ + "/test/**/*.e2e-test.ts" + ], + rootDir: ".", + transform: { + "^.+\\.tsx?$": "ts-jest" + }, + +}; diff --git a/package.json b/package.json index 27647ce..e85b469 100644 --- a/package.json +++ b/package.json @@ -12,6 +12,7 @@ "misc": "npx ts-node -r tsconfig-paths/register src/misc.ts", "prepare": "husky", "test": "jest --passWithNoTests", + "test:e2e": "jest --config=jest.e2e.config.js --passWithNoTests --runInBand", "lint": "eslint . --ext .ts --fix", "prettier": "prettier --write ./src ./test", "build-docs": "typedoc --out docs src", @@ -40,7 +41,8 @@ "dependencies": { "@swc/core": "^1.4.6", "@swc/wasm": "^1.4.6", - "axios": "^1.6.7" + "axios": "^1.6.7", + "p-limit": "^5.0.0" }, "devDependencies": { "@types/jest": "^29.5.12", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index eae292b..8a7ec9d 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -14,6 +14,9 @@ dependencies: axios: specifier: ^1.6.7 version: 1.6.7 + p-limit: + specifier: ^5.0.0 + version: 5.0.0 devDependencies: '@types/jest': @@ -3411,6 +3414,13 @@ packages: yocto-queue: 0.1.0 dev: true + /p-limit@5.0.0: + resolution: {integrity: sha512-/Eaoq+QyLSiXQ4lyYV23f14mZRQcXnxfHrN0vCai+ak9G0pp9iEQukIIZq5NccEvwRB8PUnZT0KsOoDCINS1qQ==} + engines: {node: '>=18'} + dependencies: + yocto-queue: 1.0.0 + dev: false + /p-locate@4.1.0: resolution: {integrity: sha512-R79ZZ/0wAxKGu3oYMlz8jy/kbhsNrS7SKZ7PxEHBgJ5+F2mtFW2fK2cOtBh1cHYkQsbzFV7I+EoRKe6Yt0oK7A==} engines: {node: '>=8'} @@ -4238,3 +4248,8 @@ packages: resolution: {integrity: sha512-rVksvsnNCdJ/ohGc6xgPwyN8eheCxsiLM8mxuE/t/mOVqJewPuO1miLpTHQiRgTKCLexL4MeAFVagts7HmNZ2Q==} engines: {node: '>=10'} dev: true + + /yocto-queue@1.0.0: + resolution: {integrity: sha512-9bnSc/HEW2uRy67wc+T8UwauLuPJVn28jb+GtJY16iiKWyvmYJRXVT4UamsAEGQfPohgr2q4Tq0sQbQlxTfi1g==} + engines: {node: '>=12.20'} + dev: false diff --git a/src/lib/types/common/models.ts b/src/lib/types/common/models.ts new file mode 100644 index 0000000..5e2e187 --- /dev/null +++ b/src/lib/types/common/models.ts @@ -0,0 +1,33 @@ +export const enum ModelTypes { + EMBEDDINGS = 'embeddings', + FILL_MASK = 'fill-mask', + TEXT_GENERATION = 'text-generation', + AUTOMATIC_SPEECH_RECOGNITION = 'automatic-speech-recognition', + TOKEN_CLASSIFICATION = 'token-classification', + TEXT2TEXT_GENERATION = 'text2text-generation', + OBJECT_DETECTION = 'object-detection', + QUESTION_ANSWERING = 'question-answering', + IMAGE_CLASSIFICATION = 'image-classification', + TEXT_TO_IMAGE = 'text-to-image', + ZERO_SHOT_IMAGE_CLASSIFICATION = 'zero-shot-image-classification', + CUSTOM = 'custom', + TEXT_CLASSIFICATION = 'text-classification', + DREAMBOOTH = 'dreambooth', +} + + +export interface ModelDefinition { + model_name: string; + type: ModelTypes; + reported_type: ModelTypes; + description: string; + cover_img_url: string; + tags: string[]; + pricing: { + cents_per_sec: number; + type: string; + }; + max_tokens: number | null; +} + +export type ModelDefinitionList = ModelDefinition[]; diff --git a/test/e2e/index.e2e-test.ts b/test/e2e/index.e2e-test.ts new file mode 100644 index 0000000..c2a913d --- /dev/null +++ b/test/e2e/index.e2e-test.ts @@ -0,0 +1,174 @@ +import axios from 'axios'; +import {ModelDefinition, ModelDefinitionList, ModelTypes} from "@/lib/types/common/models"; +import { + AutomaticSpeechRecognition, + Embeddings, + FillMask, + ObjectDetection, + QuestionAnswering, + Sdxl, + TextClassification, + TextGeneration, + TextToImage, + TokenClassification +} from "@/index"; + +const GET_MODELS = "https://api.deepinfra.com/models/list"; +const SDXL_MODEL = "stability-ai/sdxl"; +const TEXT_TO_IMAGE_PROMPT = "The quick brown fox jumps over the lazy dog."; +const TEXT_INPUT = "This is a test."; + +/* +TODO: Add mock audio file for ASR. +TODO: Add mock image file for object detection. +TODO: Implement p-limit + */ + + +describe('E2E tests', () => { + + let allModels: ModelDefinitionList; + + beforeAll(async () => { + const response = await axios.get(GET_MODELS); + allModels = response.data as ModelDefinitionList; + }); + + it('should have at least one model', () => { + expect(allModels.length).toBeGreaterThan(0); + }); + it('Text to image models should infer correctly.', () => { + const textToImageModels = allModels.filter(model => model.reported_type === ModelTypes.TEXT_TO_IMAGE).map(m => m.model_name); + textToImageModels.forEach(async (modelName) => { + + if (modelName === SDXL_MODEL) { + const model = new Sdxl(); + expect(model).toBeDefined(); + await model.generate({input: {prompt: TEXT_TO_IMAGE_PROMPT}}).then(response => { + expect(response).toBeDefined(); + }); + } else { + const model = new TextToImage(modelName); + expect(model).toBeDefined(); + + await model.generate({prompt: TEXT_TO_IMAGE_PROMPT}).then(response => { + expect(response).toBeDefined(); + }); + } + + }); + }); + + it('Text classification models should infer correctly.', () => { + const textClassificationModels = allModels.filter(model => model.reported_type === ModelTypes.TEXT_CLASSIFICATION).map(m => m.model_name); + textClassificationModels.forEach(async (modelName) => { + const model = new TextClassification(modelName); + expect(model).toBeDefined(); + + await model.generate({input: TEXT_INPUT}).then(response => { + expect(response).toBeDefined(); + }); + }); + }); + + it('Text generation models should infer correctly.', () => { + const textGenerationModels = allModels.filter(model => model.reported_type === ModelTypes.TEXT_GENERATION).map(m => m.model_name); + textGenerationModels.forEach(async (modelName) => { + const model = new TextGeneration(modelName); + expect(model).toBeDefined(); + + await model.generate({input: TEXT_INPUT}).then(response => { + expect(response).toBeDefined(); + }); + }); + }); + + it('Fill mask models should infer correctly.', () => { + const fillMaskModels = allModels.filter(model => model.reported_type === ModelTypes.FILL_MASK).map(m => m.model_name); + fillMaskModels.forEach(async (modelName) => { + const model = new FillMask(modelName); + expect(model).toBeDefined(); + + await model.generate({input: "This is a [MASK]"}) + .then(response => { + expect(response).toBeDefined(); + }); + }); + }); + + it('Embeddings models should infer correctly.', () => { + const embeddingsModels = allModels.filter(model => model.reported_type === ModelTypes.EMBEDDINGS).map(m => m.model_name); + embeddingsModels.forEach(async (modelName) => { + const model = new Embeddings(modelName); + expect(model).toBeDefined(); + + await model.generate({inputs: [TEXT_INPUT]}).then(response => { + expect(response).toBeDefined(); + }); + }); + }); + + + + it('Question answering models should infer correctly.', () => { + const questionAnsweringModels = allModels.filter(model => model.reported_type === ModelTypes.QUESTION_ANSWERING).map(m => m.model_name); + questionAnsweringModels.forEach(async (modelName) => { + const model = new QuestionAnswering(modelName); + expect(model).toBeDefined(); + + await model.generate({question: TEXT_INPUT, context: TEXT_INPUT}) + .then(response => { + expect(response).toBeDefined(); + }); + }); + }); + + it('Token classification models should infer correctly.', () => { + const tokenClassificationModels = allModels.filter(model => model.reported_type === ModelTypes.TOKEN_CLASSIFICATION).map(m => m.model_name); + tokenClassificationModels.forEach(async (modelName) => { + const model = new TokenClassification(modelName); + expect(model).toBeDefined(); + + await model.generate({input: TEXT_INPUT}).then(response => { + expect(response).toBeDefined(); + }); + }); + }); + + it('Text2Text generation models should infer correctly.', () => { + const text2TextGenerationModels = allModels.filter(model => model.reported_type === ModelTypes.TEXT2TEXT_GENERATION).map(m => m.model_name); + text2TextGenerationModels.forEach(async (modelName) => { + const model = new TextGeneration(modelName); + expect(model).toBeDefined(); + + await model.generate({input: TEXT_INPUT}).then(response => { + expect(response).toBeDefined(); + }); + }); + }); + + it('Object detection models should infer correctly.', () => { + const objectDetectionModels = allModels.filter(model => model.reported_type === ModelTypes.OBJECT_DETECTION).map(m => m.model_name); + objectDetectionModels.forEach(async (modelName) => { + const model = new ObjectDetection(modelName); + expect(model).toBeDefined(); + + await model.generate({input: TEXT_INPUT}).then(response => { + expect(response).toBeDefined(); + }); + }); + }); + + it('Automatic speech recognition models should infer correctly.', () => { + const automaticSpeechRecognitionModels = allModels.filter(model => model.reported_type === ModelTypes.AUTOMATIC_SPEECH_RECOGNITION).map(m => m.model_name); + automaticSpeechRecognitionModels.forEach(async (modelName) => { + const model = new AutomaticSpeechRecognition(modelName); + expect(model).toBeDefined(); + + await model.generate({input: TEXT_INPUT}).then(response => { + expect(response).toBeDefined(); + }); + }); + }); + +});