From cbb1950135fb5c16a2c3068469e28e2a063bf900 Mon Sep 17 00:00:00 2001 From: Wu Clan Date: Fri, 5 Dec 2025 17:25:59 +0800 Subject: [PATCH 1/4] Add ai plugin --- README.md | 8 ++------ __init__.py | 0 api/router.py | 8 ++++++++ api/v1/__init__.py | 0 api/v1/chat.py | 8 ++++++++ crud/__init__.py | 0 model/__init__.py | 0 plugin.toml | 8 ++++++++ requirements.txt | 1 + schema/__init__.py | 0 service/__init__.py | 0 sql/mysql/init.sql | 0 sql/mysql/init_snowflake.sql | 0 sql/postgrsql/init.sql | 0 sql/postgrsql/init_snowflake.sql | 0 15 files changed, 27 insertions(+), 6 deletions(-) create mode 100644 __init__.py create mode 100644 api/router.py create mode 100644 api/v1/__init__.py create mode 100644 api/v1/chat.py create mode 100644 crud/__init__.py create mode 100644 model/__init__.py create mode 100644 requirements.txt create mode 100644 schema/__init__.py create mode 100644 service/__init__.py create mode 100644 sql/mysql/init.sql create mode 100644 sql/mysql/init_snowflake.sql create mode 100644 sql/postgrsql/init.sql create mode 100644 sql/postgrsql/init_snowflake.sql diff --git a/README.md b/README.md index 52c6662..dfccc62 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,3 @@ -请在此处填写插件使用说明和您的联系方式 +## AI -如果插件需要付费,请提供付费相关说明 - -如有配套前端插件,请添加前端插件仓库链接说明 - -插件开发文档:[fba plugin dev](https://fastapi-practices.github.io/fastapi_best_architecture_docs/plugin/dev.html) +此插件提供了 AI 能力 diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/router.py b/api/router.py new file mode 100644 index 0000000..d7ad691 --- /dev/null +++ b/api/router.py @@ -0,0 +1,8 @@ +from fastapi import APIRouter + +from backend.core.conf import settings +from backend.plugin.ai.api.v1.chat import router as chat_router + +v1 = APIRouter(prefix=settings.FASTAPI_API_V1_PATH) + +v1.include_router(chat_router, prefix='/chat', tags=['AI 文本生成']) diff --git a/api/v1/__init__.py b/api/v1/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/v1/chat.py b/api/v1/chat.py new file mode 100644 index 0000000..128b456 --- /dev/null +++ b/api/v1/chat.py @@ -0,0 +1,8 @@ +from fastapi import APIRouter +from starlette.responses import StreamingResponse + +router = APIRouter() + + +@router.post('/completions', summary='文本生成(对话)') +async def completions() -> StreamingResponse: ... diff --git a/crud/__init__.py b/crud/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/model/__init__.py b/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/plugin.toml b/plugin.toml index e69de29..3b41936 100644 --- a/plugin.toml +++ b/plugin.toml @@ -0,0 +1,8 @@ +[plugin] +summary = 'AI 工具' +version = '0.0.1' +description = '为系统提供 AI 赋能' +author = 'wu-clan' + +[app] +router = ['v1'] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..a06c42d --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +pydantic-ai-slim[openai,google,anthropic,groq,mcp]>=1.22.0 diff --git a/schema/__init__.py b/schema/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/service/__init__.py b/service/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sql/mysql/init.sql b/sql/mysql/init.sql new file mode 100644 index 0000000..e69de29 diff --git a/sql/mysql/init_snowflake.sql b/sql/mysql/init_snowflake.sql new file mode 100644 index 0000000..e69de29 diff --git a/sql/postgrsql/init.sql b/sql/postgrsql/init.sql new file mode 100644 index 0000000..e69de29 diff --git a/sql/postgrsql/init_snowflake.sql b/sql/postgrsql/init_snowflake.sql new file mode 100644 index 0000000..e69de29 From 8fda59e287e504dae711a1a800ed82b4dd35b4c9 Mon Sep 17 00:00:00 2001 From: Wu Clan Date: Tue, 16 Dec 2025 16:52:11 +0800 Subject: [PATCH 2/4] Fix the pgsql script path --- sql/{postgrsql => postgresql}/init.sql | 0 sql/{postgrsql => postgresql}/init_snowflake.sql | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename sql/{postgrsql => postgresql}/init.sql (100%) rename sql/{postgrsql => postgresql}/init_snowflake.sql (100%) diff --git a/sql/postgrsql/init.sql b/sql/postgresql/init.sql similarity index 100% rename from sql/postgrsql/init.sql rename to sql/postgresql/init.sql diff --git a/sql/postgrsql/init_snowflake.sql b/sql/postgresql/init_snowflake.sql similarity index 100% rename from sql/postgrsql/init_snowflake.sql rename to sql/postgresql/init_snowflake.sql From e5da6965528ad80aff82a40e94c6f22da2fbd8a0 Mon Sep 17 00:00:00 2001 From: Wu Clan Date: Thu, 18 Dec 2025 15:25:27 +0800 Subject: [PATCH 3/4] Add provider management --- api/v1/provider.py | 85 ++++++++++++++++++++++++++++++++++++ crud/crud_provider.py | 67 ++++++++++++++++++++++++++++ model/__init__.py | 1 + model/provider.py | 19 ++++++++ requirements.txt | 2 +- schema/provider.py | 41 +++++++++++++++++ service/provider_service.py | 87 +++++++++++++++++++++++++++++++++++++ 7 files changed, 301 insertions(+), 1 deletion(-) create mode 100644 api/v1/provider.py create mode 100644 crud/crud_provider.py create mode 100644 model/provider.py create mode 100644 schema/provider.py create mode 100644 service/provider_service.py diff --git a/api/v1/provider.py b/api/v1/provider.py new file mode 100644 index 0000000..6a41c1c --- /dev/null +++ b/api/v1/provider.py @@ -0,0 +1,85 @@ +from typing import Annotated + +from fastapi import APIRouter, Depends, Path + +from backend.common.pagination import DependsPagination, PageData +from backend.common.response.response_schema import ResponseModel, ResponseSchemaModel, response_base +from backend.common.security.jwt import DependsJwtAuth +from backend.common.security.permission import RequestPermission +from backend.common.security.rbac import DependsRBAC +from backend.database.db import CurrentSession, CurrentSessionTransaction +from backend.plugin.ai.schema.provider import ( + CreateAiProviderParam, + DeleteAiProviderParam, + GetAiProviderDetail, + UpdateAiProviderParam, +) +from backend.plugin.ai.service.provider_service import ai_provider_service + +router = APIRouter() + + +@router.get('/{pk}', summary='获取供应商详情', dependencies=[DependsJwtAuth]) +async def get_ai_provider( + db: CurrentSession, pk: Annotated[int, Path(description='provider ID')] +) -> ResponseSchemaModel[GetAiProviderDetail]: + ai_provider = await ai_provider_service.get(db=db, pk=pk) + return response_base.success(data=ai_provider) + + +@router.get( + '', + summary='分页获取所有供应商', + dependencies=[ + DependsJwtAuth, + DependsPagination, + ], +) +async def get_ai_providers_paginated(db: CurrentSession) -> ResponseSchemaModel[PageData[GetAiProviderDetail]]: + page_data = await ai_provider_service.get_list(db=db) + return response_base.success(data=page_data) + + +@router.post( + '', + summary='创建provider', + dependencies=[ + Depends(RequestPermission('ai:provider:add')), + DependsRBAC, + ], +) +async def create_ai_provider(db: CurrentSessionTransaction, obj: CreateAiProviderParam) -> ResponseModel: + await ai_provider_service.create(db=db, obj=obj) + return response_base.success() + + +@router.put( + '/{pk}', + summary='更新供应商', + dependencies=[ + Depends(RequestPermission('ai:provider:edit')), + DependsRBAC, + ], +) +async def update_ai_provider( + db: CurrentSessionTransaction, pk: Annotated[int, Path(description='供应商 ID')], obj: UpdateAiProviderParam +) -> ResponseModel: + count = await ai_provider_service.update(db=db, pk=pk, obj=obj) + if count > 0: + return response_base.success() + return response_base.fail() + + +@router.delete( + '', + summary='批量删除供应商', + dependencies=[ + Depends(RequestPermission('ai:provider:del')), + DependsRBAC, + ], +) +async def delete_ai_providers(db: CurrentSessionTransaction, obj: DeleteAiProviderParam) -> ResponseModel: + count = await ai_provider_service.delete(db=db, obj=obj) + if count > 0: + return response_base.success() + return response_base.fail() diff --git a/crud/crud_provider.py b/crud/crud_provider.py new file mode 100644 index 0000000..3b4cd3d --- /dev/null +++ b/crud/crud_provider.py @@ -0,0 +1,67 @@ +from collections.abc import Sequence + +from sqlalchemy import Select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy_crud_plus import CRUDPlus + +from backend.plugin.ai.model import AiProvider +from backend.plugin.ai.schema.provider import CreateAiProviderParam, UpdateAiProviderParam + + +class CRUDAiProvider(CRUDPlus[AiProvider]): + async def get(self, db: AsyncSession, pk: int) -> AiProvider | None: + """ + 获取供应商 + + :param db: 数据库会话 + :param pk: 供应商 ID + :return: + """ + return await self.select_model(db, pk) + + async def get_select(self) -> Select: + """获取供应商列表查询表达式""" + return await self.select_order('id', 'desc') + + async def get_all(self, db: AsyncSession) -> Sequence[AiProvider]: + """ + 获取所有供应商 + + :param db: 数据库会话 + :return: + """ + return await self.select_models(db) + + async def create(self, db: AsyncSession, obj: CreateAiProviderParam) -> None: + """ + 创建供应商 + + :param db: 数据库会话 + :param obj: 创建供应商参数 + :return: + """ + await self.create_model(db, obj) + + async def update(self, db: AsyncSession, pk: int, obj: UpdateAiProviderParam) -> int: + """ + 更新供应商 + + :param db: 数据库会话 + :param pk: 供应商 ID + :param obj: 更新 供应商参数 + :return: + """ + return await self.update_model(db, pk, obj) + + async def delete(self, db: AsyncSession, pks: list[int]) -> int: + """ + 批量删除供应商 + + :param db: 数据库会话 + :param pks: 供应商 ID 列表 + :return: + """ + return await self.delete_model_by_column(db, allow_multiple=True, id__in=pks) + + +ai_provider_dao: CRUDAiProvider = CRUDAiProvider(AiProvider) diff --git a/model/__init__.py b/model/__init__.py index e69de29..e78bca8 100644 --- a/model/__init__.py +++ b/model/__init__.py @@ -0,0 +1 @@ +from backend.plugin.ai.model.provider import AiProvider as AiProvider diff --git a/model/provider.py b/model/provider.py new file mode 100644 index 0000000..5cf2aab --- /dev/null +++ b/model/provider.py @@ -0,0 +1,19 @@ +import sqlalchemy as sa + +from sqlalchemy.orm import Mapped, mapped_column + +from backend.common.model import Base, UniversalText, id_key + + +class AiProvider(Base): + """AI 供应商""" + + __tablename__ = 'ai_provider' + + id: Mapped[id_key] = mapped_column(init=False) + name: Mapped[str] = mapped_column(sa.String(256), comment='供应商名称') + type: Mapped[int] = mapped_column(comment='供应商类型(0OpenAI 1Anthropic 2Gemini)') + api_key: Mapped[str] = mapped_column(UniversalText, comment='API Key') + api_host: Mapped[str] = mapped_column(sa.String(512), comment='API Host') + status: Mapped[int] = mapped_column(default=1, comment='角色状态(0停用 1正常)') + remark: Mapped[str | None] = mapped_column(UniversalText, default=None, comment='备注') diff --git a/requirements.txt b/requirements.txt index a06c42d..a5aacfb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1 @@ -pydantic-ai-slim[openai,google,anthropic,groq,mcp]>=1.22.0 +pydantic-ai-slim[openai,google,anthropic,groq,mcp]>=1.34.0 diff --git a/schema/provider.py b/schema/provider.py new file mode 100644 index 0000000..c06b9a3 --- /dev/null +++ b/schema/provider.py @@ -0,0 +1,41 @@ +from datetime import datetime + +from pydantic import ConfigDict, Field + +from backend.common.enums import StatusType +from backend.common.schema import SchemaBase + + +class AiProviderSchemaBase(SchemaBase): + """供应商基础模型""" + + name: str = Field(description='供应商名称') + type: int = Field(description='供应商类型(0OpenAI 1Anthropic 2Gemini)') + api_key: str = Field(description='API Key') + api_host: str = Field(description='API Host') + status: StatusType = Field(description='状态') + remark: str | None = Field(None, description='备注') + + +class CreateAiProviderParam(AiProviderSchemaBase): + """创建供应商参数""" + + +class UpdateAiProviderParam(AiProviderSchemaBase): + """更新供应商参数""" + + +class DeleteAiProviderParam(SchemaBase): + """删除供应商参数""" + + pks: list[int] = Field(description='供应商 ID 列表') + + +class GetAiProviderDetail(AiProviderSchemaBase): + """供应商详情""" + + model_config = ConfigDict(from_attributes=True) + + id: int + created_time: datetime + updated_time: datetime | None = None diff --git a/service/provider_service.py b/service/provider_service.py new file mode 100644 index 0000000..f5b4a34 --- /dev/null +++ b/service/provider_service.py @@ -0,0 +1,87 @@ +from collections.abc import Sequence +from typing import Any + +from sqlalchemy.ext.asyncio import AsyncSession + +from backend.common.exception import errors +from backend.common.pagination import paging_data +from backend.plugin.ai.crud.crud_provider import ai_provider_dao +from backend.plugin.ai.model import AiProvider +from backend.plugin.ai.schema.provider import CreateAiProviderParam, DeleteAiProviderParam, UpdateAiProviderParam + + +class AiProviderService: + @staticmethod + async def get(*, db: AsyncSession, pk: int) -> AiProvider: + """ + 获取供应商 + + :param db: 数据库会话 + :param pk: 供应商 ID + :return: + """ + ai_provider = await ai_provider_dao.get(db, pk) + if not ai_provider: + raise errors.NotFoundError(msg='供应商不存在') + return ai_provider + + @staticmethod + async def get_list(db: AsyncSession) -> dict[str, Any]: + """ + 获取供应商列表 + + :param db: 数据库会话 + :return: + """ + ai_provider_select = await ai_provider_dao.get_select() + return await paging_data(db, ai_provider_select) + + @staticmethod + async def get_all(*, db: AsyncSession) -> Sequence[AiProvider]: + """ + 获取所有供应商 + + :param db: 数据库会话 + :return: + """ + ai_providers = await ai_provider_dao.get_all(db) + return ai_providers + + @staticmethod + async def create(*, db: AsyncSession, obj: CreateAiProviderParam) -> None: + """ + 创建供应商 + + :param db: 数据库会话 + :param obj: 创建供应商参数 + :return: + """ + await ai_provider_dao.create(db, obj) + + @staticmethod + async def update(*, db: AsyncSession, pk: int, obj: UpdateAiProviderParam) -> int: + """ + 更新供应商 + + :param db: 数据库会话 + :param pk: 供应商 ID + :param obj: 更新供应商参数 + :return: + """ + count = await ai_provider_dao.update(db, pk, obj) + return count + + @staticmethod + async def delete(*, db: AsyncSession, obj: DeleteAiProviderParam) -> int: + """ + 删除供应商 + + :param db: 数据库会话 + :param obj: 供应商 ID 列表 + :return: + """ + count = await ai_provider_dao.delete(db, obj.pks) + return count + + +ai_provider_service: AiProviderService = AiProviderService() From 7eff13384a57292f29dc4053c3dabeff11cdf855 Mon Sep 17 00:00:00 2001 From: Wu Clan Date: Wed, 24 Dec 2025 11:57:48 +0800 Subject: [PATCH 4/4] Add model manage and other codes --- api/router.py | 4 ++ api/v1/chat.py | 7 +- api/v1/model.py | 85 ++++++++++++++++++++++++ api/v1/provider.py | 39 ++++++++--- crud/crud_model.py | 99 ++++++++++++++++++++++++++++ crud/crud_provider.py | 16 ++--- enums.py | 17 +++++ model/__init__.py | 3 +- model/model.py | 18 +++++ model/provider.py | 6 +- requirements.txt | 2 +- schema/chat.py | 28 ++++++++ schema/model.py | 40 ++++++++++++ schema/provider.py | 32 +++++---- service/chat_service.py | 81 +++++++++++++++++++++++ service/model_service.py | 85 ++++++++++++++++++++++++ service/provider_service.py | 89 ++++++++++++++++++++----- utils/__init__.py | 0 utils/message_parse.py | 32 +++++++++ utils/model_control.py | 127 ++++++++++++++++++++++++++++++++++++ 20 files changed, 759 insertions(+), 51 deletions(-) create mode 100644 api/v1/model.py create mode 100644 crud/crud_model.py create mode 100644 enums.py create mode 100644 model/model.py create mode 100644 schema/chat.py create mode 100644 schema/model.py create mode 100644 service/chat_service.py create mode 100644 service/model_service.py create mode 100644 utils/__init__.py create mode 100644 utils/message_parse.py create mode 100644 utils/model_control.py diff --git a/api/router.py b/api/router.py index d7ad691..a186a59 100644 --- a/api/router.py +++ b/api/router.py @@ -2,7 +2,11 @@ from backend.core.conf import settings from backend.plugin.ai.api.v1.chat import router as chat_router +from backend.plugin.ai.api.v1.model import router as model_router +from backend.plugin.ai.api.v1.provider import router as provider_router v1 = APIRouter(prefix=settings.FASTAPI_API_V1_PATH) v1.include_router(chat_router, prefix='/chat', tags=['AI 文本生成']) +v1.include_router(model_router, prefix='/models', tags=['AI 模型管理']) +v1.include_router(provider_router, prefix='/providers', tags=['AI 供应商管理']) diff --git a/api/v1/chat.py b/api/v1/chat.py index 128b456..7816c28 100644 --- a/api/v1/chat.py +++ b/api/v1/chat.py @@ -1,8 +1,13 @@ from fastapi import APIRouter from starlette.responses import StreamingResponse +from backend.database.db import CurrentSession +from backend.plugin.ai.schema.chat import AIChat +from backend.plugin.ai.service.chat_service import ai_chat_service + router = APIRouter() @router.post('/completions', summary='文本生成(对话)') -async def completions() -> StreamingResponse: ... +async def completions(db: CurrentSession, chat: AIChat) -> StreamingResponse: + return StreamingResponse(ai_chat_service.stream_messages(db=db, chat=chat), media_type='text/plain') diff --git a/api/v1/model.py b/api/v1/model.py new file mode 100644 index 0000000..74c783c --- /dev/null +++ b/api/v1/model.py @@ -0,0 +1,85 @@ +from typing import Annotated + +from fastapi import APIRouter, Depends, Path + +from backend.common.pagination import DependsPagination, PageData +from backend.common.response.response_schema import ResponseModel, ResponseSchemaModel, response_base +from backend.common.security.jwt import DependsJwtAuth +from backend.common.security.permission import RequestPermission +from backend.common.security.rbac import DependsRBAC +from backend.database.db import CurrentSession, CurrentSessionTransaction +from backend.plugin.ai.schema.model import ( + CreateAIModelParam, + DeleteAIModelParam, + GetAIModelDetail, + UpdateAIModelParam, +) +from backend.plugin.ai.service.model_service import ai_model_service + +router = APIRouter() + + +@router.get('/{pk}', summary='获取模型详情', dependencies=[DependsJwtAuth]) +async def get_ai_model( + db: CurrentSession, pk: Annotated[int, Path(description='模型 ID')] +) -> ResponseSchemaModel[GetAIModelDetail]: + ai_model = await ai_model_service.get(db=db, pk=pk) + return response_base.success(data=ai_model) + + +@router.get( + '', + summary='分页获取所有模型', + dependencies=[ + DependsJwtAuth, + DependsPagination, + ], +) +async def get_ai_models_paginated(db: CurrentSession) -> ResponseSchemaModel[PageData[GetAIModelDetail]]: + page_data = await ai_model_service.get_list(db=db) + return response_base.success(data=page_data) + + +@router.post( + '', + summary='创建模型', + dependencies=[ + Depends(RequestPermission('ai:model:add')), + DependsRBAC, + ], +) +async def create_ai_model(db: CurrentSessionTransaction, obj: CreateAIModelParam) -> ResponseModel: + await ai_model_service.create(db=db, obj=obj) + return response_base.success() + + +@router.put( + '/{pk}', + summary='更新模型', + dependencies=[ + Depends(RequestPermission('ai:model:edit')), + DependsRBAC, + ], +) +async def update_ai_model( + db: CurrentSessionTransaction, pk: Annotated[int, Path(description='模型 ID')], obj: UpdateAIModelParam +) -> ResponseModel: + count = await ai_model_service.update(db=db, pk=pk, obj=obj) + if count > 0: + return response_base.success() + return response_base.fail() + + +@router.delete( + '', + summary='批量删除模型', + dependencies=[ + Depends(RequestPermission('ai:model:del')), + DependsRBAC, + ], +) +async def delete_ai_models(db: CurrentSessionTransaction, obj: DeleteAIModelParam) -> ResponseModel: + count = await ai_model_service.delete(db=db, obj=obj) + if count > 0: + return response_base.success() + return response_base.fail() diff --git a/api/v1/provider.py b/api/v1/provider.py index 6a41c1c..45d5797 100644 --- a/api/v1/provider.py +++ b/api/v1/provider.py @@ -9,10 +9,11 @@ from backend.common.security.rbac import DependsRBAC from backend.database.db import CurrentSession, CurrentSessionTransaction from backend.plugin.ai.schema.provider import ( - CreateAiProviderParam, - DeleteAiProviderParam, - GetAiProviderDetail, - UpdateAiProviderParam, + CreateAIProviderParam, + DeleteAIProviderParam, + GetAIProviderDetail, + GetAIProviderModelDetail, + UpdateAIProviderParam, ) from backend.plugin.ai.service.provider_service import ai_provider_service @@ -22,11 +23,29 @@ @router.get('/{pk}', summary='获取供应商详情', dependencies=[DependsJwtAuth]) async def get_ai_provider( db: CurrentSession, pk: Annotated[int, Path(description='provider ID')] -) -> ResponseSchemaModel[GetAiProviderDetail]: +) -> ResponseSchemaModel[GetAIProviderDetail]: ai_provider = await ai_provider_service.get(db=db, pk=pk) return response_base.success(data=ai_provider) +@router.get('/{pk}/models', summary='获取供应商模型列表', dependencies=[DependsJwtAuth]) +async def get_ai_provider_models( + db: CurrentSession, + pk: Annotated[int, Path(description='provider ID')], +) -> ResponseSchemaModel[list[GetAIProviderModelDetail]]: + ai_provider_modes = await ai_provider_service.get_models(db=db, pk=pk) + return response_base.success(data=ai_provider_modes) + + +@router.get('/{pk}/models/sync', summary='同步供应商模型', dependencies=[DependsJwtAuth]) +async def sync_ai_provider_models( + db: CurrentSessionTransaction, + pk: Annotated[int, Path(description='provider ID')], +) -> ResponseModel: + await ai_provider_service.sync_models(db=db, pk=pk) + return response_base.success() + + @router.get( '', summary='分页获取所有供应商', @@ -35,20 +54,20 @@ async def get_ai_provider( DependsPagination, ], ) -async def get_ai_providers_paginated(db: CurrentSession) -> ResponseSchemaModel[PageData[GetAiProviderDetail]]: +async def get_ai_providers_paginated(db: CurrentSession) -> ResponseSchemaModel[PageData[GetAIProviderDetail]]: page_data = await ai_provider_service.get_list(db=db) return response_base.success(data=page_data) @router.post( '', - summary='创建provider', + summary='创建供应商', dependencies=[ Depends(RequestPermission('ai:provider:add')), DependsRBAC, ], ) -async def create_ai_provider(db: CurrentSessionTransaction, obj: CreateAiProviderParam) -> ResponseModel: +async def create_ai_provider(db: CurrentSessionTransaction, obj: CreateAIProviderParam) -> ResponseModel: await ai_provider_service.create(db=db, obj=obj) return response_base.success() @@ -62,7 +81,7 @@ async def create_ai_provider(db: CurrentSessionTransaction, obj: CreateAiProvide ], ) async def update_ai_provider( - db: CurrentSessionTransaction, pk: Annotated[int, Path(description='供应商 ID')], obj: UpdateAiProviderParam + db: CurrentSessionTransaction, pk: Annotated[int, Path(description='供应商 ID')], obj: UpdateAIProviderParam ) -> ResponseModel: count = await ai_provider_service.update(db=db, pk=pk, obj=obj) if count > 0: @@ -78,7 +97,7 @@ async def update_ai_provider( DependsRBAC, ], ) -async def delete_ai_providers(db: CurrentSessionTransaction, obj: DeleteAiProviderParam) -> ResponseModel: +async def delete_ai_providers(db: CurrentSessionTransaction, obj: DeleteAIProviderParam) -> ResponseModel: count = await ai_provider_service.delete(db=db, obj=obj) if count > 0: return response_base.success() diff --git a/crud/crud_model.py b/crud/crud_model.py new file mode 100644 index 0000000..300e3fe --- /dev/null +++ b/crud/crud_model.py @@ -0,0 +1,99 @@ +from collections.abc import Sequence +from typing import Any + +from sqlalchemy import Select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy_crud_plus import CRUDPlus + +from backend.plugin.ai.model import AIModel +from backend.plugin.ai.schema.model import CreateAIModelParam, UpdateAIModelParam + + +class CRUDAIModel(CRUDPlus[AIModel]): + async def get(self, db: AsyncSession, pk: int) -> AIModel | None: + """ + 获取模型 + + :param db: 数据库会话 + :param pk: 模型 ID + :return: + """ + return await self.select_model(db, pk) + + async def get_by_model_and_provider(self, db: AsyncSession, model_id: str, provider_id: int) -> AIModel | None: + """ + 通过模型和供应商获取模型 + + :param db: 数据库会话 + :param model_id: 模型 + :param provider_id: 供应商 + :return: + """ + return await self.select_model_by_column(db, model_id=model_id, provider_id=provider_id) + + async def get_select(self) -> Select: + """获取模型列表查询表达式""" + return await self.select_order('id', 'desc') + + async def get_all(self, db: AsyncSession) -> Sequence[AIModel]: + """ + 获取所有模型 + + :param db: 数据库会话 + :return: + """ + return await self.select_models(db) + + async def create(self, db: AsyncSession, obj: CreateAIModelParam) -> None: + """ + 创建模型 + + :param db: 数据库会话 + :param obj: 创建模型参数 + :return: + """ + await self.create_model(db, obj) + + async def bulk_create(self, db: AsyncSession, objs: list[dict[str, Any]]) -> None: + """ + 批量创建模型 + + :param db:数据库会话 + :param objs: 批量创建模型参数 + :return: + """ + await self.bulk_create_models(db, objs) + + async def update(self, db: AsyncSession, pk: int, obj: UpdateAIModelParam) -> int: + """ + 更新模型 + + :param db: 数据库会话 + :param pk: 模型 ID + :param obj: 更新 模型参数 + :return: + """ + return await self.update_model(db, pk, obj) + + async def delete(self, db: AsyncSession, pks: list[int]) -> int: + """ + 批量删除模型 + + :param db: 数据库会话 + :param pks: 模型 ID 列表 + :return: + """ + return await self.delete_model_by_column(db, allow_multiple=True, id__in=pks) + + async def delete_by_provider(self, db: AsyncSession, provider_id: int) -> int: + """ + 通过供应商 ID 删除模型 + + :param db: 数据库会话 + :param provider_id: 供应商 ID + :return: + """ + return await self.delete_model_by_column(db, allow_multiple=True, provider_id=provider_id) + + +ai_model_dao: CRUDAIModel = CRUDAIModel(AIModel) diff --git a/crud/crud_provider.py b/crud/crud_provider.py index 3b4cd3d..bf78f80 100644 --- a/crud/crud_provider.py +++ b/crud/crud_provider.py @@ -4,12 +4,12 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy_crud_plus import CRUDPlus -from backend.plugin.ai.model import AiProvider -from backend.plugin.ai.schema.provider import CreateAiProviderParam, UpdateAiProviderParam +from backend.plugin.ai.model import AIProvider +from backend.plugin.ai.schema.provider import CreateAIProviderParam, UpdateAIProviderParam -class CRUDAiProvider(CRUDPlus[AiProvider]): - async def get(self, db: AsyncSession, pk: int) -> AiProvider | None: +class CRUDAIProvider(CRUDPlus[AIProvider]): + async def get(self, db: AsyncSession, pk: int) -> AIProvider | None: """ 获取供应商 @@ -23,7 +23,7 @@ async def get_select(self) -> Select: """获取供应商列表查询表达式""" return await self.select_order('id', 'desc') - async def get_all(self, db: AsyncSession) -> Sequence[AiProvider]: + async def get_all(self, db: AsyncSession) -> Sequence[AIProvider]: """ 获取所有供应商 @@ -32,7 +32,7 @@ async def get_all(self, db: AsyncSession) -> Sequence[AiProvider]: """ return await self.select_models(db) - async def create(self, db: AsyncSession, obj: CreateAiProviderParam) -> None: + async def create(self, db: AsyncSession, obj: CreateAIProviderParam) -> None: """ 创建供应商 @@ -42,7 +42,7 @@ async def create(self, db: AsyncSession, obj: CreateAiProviderParam) -> None: """ await self.create_model(db, obj) - async def update(self, db: AsyncSession, pk: int, obj: UpdateAiProviderParam) -> int: + async def update(self, db: AsyncSession, pk: int, obj: UpdateAIProviderParam) -> int: """ 更新供应商 @@ -64,4 +64,4 @@ async def delete(self, db: AsyncSession, pks: list[int]) -> int: return await self.delete_model_by_column(db, allow_multiple=True, id__in=pks) -ai_provider_dao: CRUDAiProvider = CRUDAiProvider(AiProvider) +ai_provider_dao: CRUDAIProvider = CRUDAIProvider(AIProvider) diff --git a/enums.py b/enums.py new file mode 100644 index 0000000..a7e7108 --- /dev/null +++ b/enums.py @@ -0,0 +1,17 @@ +from backend.common.enums import IntEnum + + +class AIProviderType(IntEnum): + """AI 供应商类型""" + + openai = 0 + anthropic = 1 + gemini = 2 + bedrock = 3 + cerebras = 4 + cohere = 5 + groq = 6 + hugging_face = 7 + mistral = 8 + openrouter = 9 + outlines = 10 diff --git a/model/__init__.py b/model/__init__.py index e78bca8..89a244b 100644 --- a/model/__init__.py +++ b/model/__init__.py @@ -1 +1,2 @@ -from backend.plugin.ai.model.provider import AiProvider as AiProvider +from backend.plugin.ai.model.model import AIModel as AIModel +from backend.plugin.ai.model.provider import AIProvider as AIProvider diff --git a/model/model.py b/model/model.py new file mode 100644 index 0000000..05c5548 --- /dev/null +++ b/model/model.py @@ -0,0 +1,18 @@ +import sqlalchemy as sa + +from sqlalchemy.orm import Mapped, mapped_column + +from backend.common.model import Base, UniversalText, id_key + + +class AIModel(Base): + """AI 模型""" + + __tablename__ = 'ai_model' + + id: Mapped[id_key] = mapped_column(init=False) + provider_id: Mapped[int] = mapped_column(sa.BigInteger, comment='供应商关联 ID') + model_id: Mapped[str] = mapped_column(sa.String(512), comment='模型 ID') + owned_by: Mapped[str] = mapped_column(sa.String(512), comment='拥有该模型的组织') + status: Mapped[int] = mapped_column(default=1, comment='角色状态(0停用 1正常)') + remark: Mapped[str | None] = mapped_column(UniversalText, default=None, comment='备注') diff --git a/model/provider.py b/model/provider.py index 5cf2aab..6ea3df6 100644 --- a/model/provider.py +++ b/model/provider.py @@ -5,14 +5,16 @@ from backend.common.model import Base, UniversalText, id_key -class AiProvider(Base): +class AIProvider(Base): """AI 供应商""" __tablename__ = 'ai_provider' id: Mapped[id_key] = mapped_column(init=False) name: Mapped[str] = mapped_column(sa.String(256), comment='供应商名称') - type: Mapped[int] = mapped_column(comment='供应商类型(0OpenAI 1Anthropic 2Gemini)') + type: Mapped[int] = mapped_column( + comment='供应商类型(0:OpenAI 1:Anthropic 2:Gemini 3:Bedrock 4:Cerebras 5:Cohere 6:Groq 7:HuggingFace 8:Mistral 9:OpenRouter 10:Outlines)' # noqa: E501 + ) api_key: Mapped[str] = mapped_column(UniversalText, comment='API Key') api_host: Mapped[str] = mapped_column(sa.String(512), comment='API Host') status: Mapped[int] = mapped_column(default=1, comment='角色状态(0停用 1正常)') diff --git a/requirements.txt b/requirements.txt index a5aacfb..2c4dd36 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1 @@ -pydantic-ai-slim[openai,google,anthropic,groq,mcp]>=1.34.0 +pydantic-ai-slim[openai,anthropic,google,bedrock,cohere,groq,huggingface,mistral,outlines-vllm-offline,mcp]>=1.38.0 diff --git a/schema/chat.py b/schema/chat.py new file mode 100644 index 0000000..9ce7f15 --- /dev/null +++ b/schema/chat.py @@ -0,0 +1,28 @@ +from pydantic import Field + +from backend.common.schema import SchemaBase + + +class AIChatSchemaBase(SchemaBase): + """聊天基础模型""" + + max_tokens: int | None = Field(default=None, description='停止前最多可生成的 token 数') + temperature: float | None = Field(default=1.0, description='模型生成文本的随机性') + top_p: float | None = Field(default=1.0, description='模型生成文本的多样性') + timeout: float | None = Field(default=None, description='覆盖客户端对请求的默认超时(单位:s)') + parallel_tool_calls: bool | None = Field(default=True, description='是否允许并行工具调用') + seed: int | None = Field(default=None, description='用于模型的随机种子') + presence_penalty: float | None = Field(default=None, description='根据新 token 是否出现在文本中来处罚') + frequency_penalty: float | None = Field(default=None, description='根据新 token 目前在文本中的出现频率进行惩罚') + logit_bias: dict[str, int] | None = Field(default=None, description='修改完成中出现指定标记的可能性') + stop_sequences: list[str] | None = Field(default=None, description='这些序列会导致模型停止生成') + extra_headers: dict[str, str] | None = Field(default=None, description='发送给模型的额外 Headers') + extra_body: object | None = Field(default=None, description='发送给模型的额外请求体') + + +class AIChat(AIChatSchemaBase): + """聊天参数""" + + provider_id: int = Field(description='供应商 ID') + model_id: str = Field(description='聊天模型') + user_prompt: str = Field(description='用户提示词') diff --git a/schema/model.py b/schema/model.py new file mode 100644 index 0000000..81f0932 --- /dev/null +++ b/schema/model.py @@ -0,0 +1,40 @@ +from datetime import datetime + +from pydantic import ConfigDict, Field + +from backend.common.enums import StatusType +from backend.common.schema import SchemaBase + + +class AIModelSchemaBase(SchemaBase): + """AI 模型基础模型""" + + provider_id: int = Field(description='供应商 ID') + model_id: str = Field(description='模型 ID') + owned_by: str = Field(description='拥有该模型的组织') + status: StatusType = Field(description='状态') + remark: str | None = Field(default=None, description='备注') + + +class CreateAIModelParam(AIModelSchemaBase): + """创建 AI 模型参数""" + + +class UpdateAIModelParam(AIModelSchemaBase): + """更新 AI 模型参数""" + + +class DeleteAIModelParam(SchemaBase): + """删除 AI 模型参数""" + + pks: list[int] = Field(description='模型 ID 列表') + + +class GetAIModelDetail(AIModelSchemaBase): + """AI 模型详情""" + + model_config = ConfigDict(from_attributes=True) + + id: int + created_time: datetime + updated_time: datetime | None = None diff --git a/schema/provider.py b/schema/provider.py index c06b9a3..95db91e 100644 --- a/schema/provider.py +++ b/schema/provider.py @@ -4,38 +4,48 @@ from backend.common.enums import StatusType from backend.common.schema import SchemaBase +from backend.plugin.ai.enums import AIProviderType -class AiProviderSchemaBase(SchemaBase): - """供应商基础模型""" +class AIProviderSchemaBase(SchemaBase): + """AI 供应商基础模型""" name: str = Field(description='供应商名称') - type: int = Field(description='供应商类型(0OpenAI 1Anthropic 2Gemini)') + type: AIProviderType = Field(description='供应商类型(0OpenAI 1Anthropic 2Gemini)') api_key: str = Field(description='API Key') api_host: str = Field(description='API Host') status: StatusType = Field(description='状态') remark: str | None = Field(None, description='备注') -class CreateAiProviderParam(AiProviderSchemaBase): - """创建供应商参数""" +class CreateAIProviderParam(AIProviderSchemaBase): + """创建 AI 供应商参数""" -class UpdateAiProviderParam(AiProviderSchemaBase): - """更新供应商参数""" +class UpdateAIProviderParam(AIProviderSchemaBase): + """更新 AI 供应商参数""" -class DeleteAiProviderParam(SchemaBase): - """删除供应商参数""" +class DeleteAIProviderParam(SchemaBase): + """删除 AI 供应商参数""" pks: list[int] = Field(description='供应商 ID 列表') -class GetAiProviderDetail(AiProviderSchemaBase): - """供应商详情""" +class GetAIProviderDetail(AIProviderSchemaBase): + """AI 供应商详情""" model_config = ConfigDict(from_attributes=True) id: int created_time: datetime updated_time: datetime | None = None + + +class GetAIProviderModelDetail(SchemaBase): + """获取供应商模型详情""" + + id: str = Field(description='模型标识符') + object: str = Field(description='对象类型始终为 “model”') + created: int = Field(description='模型创建时的 Unix 时间戳(以秒为单位)') + owned_by: str = Field(description='拥有该模型的组织') diff --git a/service/chat_service.py b/service/chat_service.py new file mode 100644 index 0000000..0b950b2 --- /dev/null +++ b/service/chat_service.py @@ -0,0 +1,81 @@ +import json + +from collections.abc import AsyncGenerator +from typing import Any + +from pydantic_ai import Agent, ModelResponse, ModelSettings, TextPart +from sqlalchemy.ext.asyncio import AsyncSession + +from backend.common.exception import errors +from backend.plugin.ai.crud.crud_model import ai_model_dao +from backend.plugin.ai.crud.crud_provider import ai_provider_dao +from backend.plugin.ai.schema.chat import AIChat +from backend.plugin.ai.utils.message_parse import to_chat_message +from backend.plugin.ai.utils.model_control import get_pydantic_model + +chat_agent = Agent(name='fba_chat') + + +class ChatService: + """聊天服务类""" + + @staticmethod + async def stream_messages(*, db: AsyncSession, chat: AIChat) -> AsyncGenerator[bytes, Any]: + """ + 流式消息 + + :param db: 数据库会话 + :param chat: 聊天参数 + :return: + """ + provider = await ai_provider_dao.get(db, chat.provider_id) + if not provider: + raise errors.NotFoundError(msg='供应商不存在') + + if not provider.status: + raise errors.RequestError(msg='此供应商暂不可用,请更换供应商或联系系统管理员') + + model = await ai_model_dao.get_by_model_and_provider(db, chat.model_id, chat.provider_id) + if not model: + raise errors.NotFoundError(msg='供应商模型不存在') + + if not model.status: + raise errors.RequestError(msg='此模型暂不可用,请更换模型或联系系统管理员') + + yield json.dumps({'role': 'user', 'content': chat.user_prompt}, ensure_ascii=False).encode('utf-8') + b'\n' + + model_settings = { + k: v + for k, v in { + 'max_tokens': chat.max_tokens, + 'temperature': chat.temperature, + 'top_p': chat.top_p, + 'timeout': chat.timeout, + 'parallel_tool_calls': chat.parallel_tool_calls, + 'seed': chat.seed, + 'presence_penalty': chat.presence_penalty, + 'frequency_penalty': chat.frequency_penalty, + 'logit_bias': chat.logit_bias, + 'stop_sequences': chat.stop_sequences, + 'extra_headers': chat.extra_headers, + 'extra_body': chat.extra_body, + }.items() + if v is not None + } + + async with chat_agent.run_stream( + chat.user_prompt, + model=get_pydantic_model( + provider_type=provider.type, + model_name=model.model_id, + api_key=provider.api_key, + base_url=provider.api_host, + model_settings=ModelSettings(**model_settings), + ), + ) as result: + async for text in result.stream_output(debounce_by=0.01): + message = ModelResponse(parts=[TextPart(text)], model_name=model.model_id, timestamp=result.timestamp()) + yield json.dumps(to_chat_message(message)).encode('utf-8') + b'\n' + + +ai_chat_service: ChatService = ChatService() diff --git a/service/model_service.py b/service/model_service.py new file mode 100644 index 0000000..cb30894 --- /dev/null +++ b/service/model_service.py @@ -0,0 +1,85 @@ +from collections.abc import Sequence +from typing import Any + +from sqlalchemy.ext.asyncio import AsyncSession + +from backend.common.pagination import paging_data +from backend.plugin.ai.crud.crud_model import ai_model_dao +from backend.plugin.ai.model import AIModel +from backend.plugin.ai.schema.model import CreateAIModelParam, DeleteAIModelParam, UpdateAIModelParam + + +class AIModelService: + """AI 模型服务""" + + @staticmethod + async def get(*, db: AsyncSession, pk: int) -> AIModel | None: + """ + 获取 AI 模型 + + :param db: 数据库会话 + :param pk: 模型 ID + :return: + """ + await ai_model_dao.get(db, pk) + + @staticmethod + async def get_list(db: AsyncSession) -> dict[str, Any]: + """ + 获取 AI 模型列表 + + :param db: 数据库会话 + :return: + """ + ai_model_select = await ai_model_dao.get_select() + return await paging_data(db, ai_model_select) + + @staticmethod + async def get_all(*, db: AsyncSession) -> Sequence[AIModel]: + """ + 获取所有 AI 模型 + + :param db: 数据库会话 + :return: + """ + ai_providers = await ai_model_dao.get_all(db) + return ai_providers + + @staticmethod + async def create(*, db: AsyncSession, obj: CreateAIModelParam) -> None: + """ + 创建 AI 模型 + + :param db: 数据库会话 + :param obj: 创建模型参数 + :return: + """ + await ai_model_dao.create(db, obj) + + @staticmethod + async def update(*, db: AsyncSession, pk: int, obj: UpdateAIModelParam) -> int: + """ + 更新 AI 模型 + + :param db: 数据库会话 + :param pk: 模型 ID + :param obj: 更新模型参数 + :return: + """ + count = await ai_model_dao.update(db, pk, obj) + return count + + @staticmethod + async def delete(*, db: AsyncSession, obj: DeleteAIModelParam) -> int: + """ + 删除 AI 模型 + + :param db: 数据库会话 + :param obj: 模型 ID 列表 + :return: + """ + count = await ai_model_dao.delete(db, obj.pks) + return count + + +ai_model_service: AIModelService = AIModelService() diff --git a/service/provider_service.py b/service/provider_service.py index f5b4a34..c7d1639 100644 --- a/service/provider_service.py +++ b/service/provider_service.py @@ -1,20 +1,32 @@ from collections.abc import Sequence from typing import Any +import httpx + from sqlalchemy.ext.asyncio import AsyncSession +from backend.common.enums import StatusType from backend.common.exception import errors +from backend.common.log import log from backend.common.pagination import paging_data +from backend.plugin.ai.crud.crud_model import ai_model_dao from backend.plugin.ai.crud.crud_provider import ai_provider_dao -from backend.plugin.ai.model import AiProvider -from backend.plugin.ai.schema.provider import CreateAiProviderParam, DeleteAiProviderParam, UpdateAiProviderParam - - -class AiProviderService: +from backend.plugin.ai.model import AIProvider +from backend.plugin.ai.schema.model import CreateAIModelParam +from backend.plugin.ai.schema.provider import ( + CreateAIProviderParam, + DeleteAIProviderParam, + GetAIProviderModelDetail, + UpdateAIProviderParam, +) +from backend.utils.timezone import timezone + + +class AIProviderService: @staticmethod - async def get(*, db: AsyncSession, pk: int) -> AiProvider: + async def get(*, db: AsyncSession, pk: int) -> AIProvider: """ - 获取供应商 + 获取 AI 供应商 :param db: 数据库会话 :param pk: 供应商 ID @@ -25,10 +37,51 @@ async def get(*, db: AsyncSession, pk: int) -> AiProvider: raise errors.NotFoundError(msg='供应商不存在') return ai_provider + async def get_models(self, *, db: AsyncSession, pk: int) -> list[GetAIProviderModelDetail]: + """获取供应商模型""" + ai_provider = await self.get(db=db, pk=pk) + async with httpx.AsyncClient(timeout=10) as client: + url = f'{ai_provider.api_host}/v1/models' + headers = {'Authorization': f'Bearer {ai_provider.api_key}'} + try: + response = await client.get(url, headers=headers) + response.raise_for_status() + except Exception as e: + log.error(f'获取供应商模型列表失败:{e}') + raise errors.ForbiddenError(msg='获取供应商模型列表失败,请稍后重试') + else: + return [GetAIProviderModelDetail(**data) for data in response.json()['data']] + + async def sync_models(self, *, db: AsyncSession, pk: int) -> None: + """ + 同步供应商模型 + + :param db: 数据库会话 + :param pk: 供应商 ID + :return: + """ + provider_models = await self.get_models(db=db, pk=pk) + await ai_model_dao.delete_by_provider(db, pk) + await ai_model_dao.bulk_create( + db, + [ + { + **CreateAIModelParam( + provider_id=pk, + model_id=obj.id, + owned_by=obj.owned_by, + status=StatusType.enable, + ).model_dump(), + 'created_time': timezone.now(), + } + for obj in provider_models + ], + ) + @staticmethod async def get_list(db: AsyncSession) -> dict[str, Any]: """ - 获取供应商列表 + 获取 AI 供应商列表 :param db: 数据库会话 :return: @@ -37,9 +90,9 @@ async def get_list(db: AsyncSession) -> dict[str, Any]: return await paging_data(db, ai_provider_select) @staticmethod - async def get_all(*, db: AsyncSession) -> Sequence[AiProvider]: + async def get_all(*, db: AsyncSession) -> Sequence[AIProvider]: """ - 获取所有供应商 + 获取所有 AI 供应商 :param db: 数据库会话 :return: @@ -48,20 +101,22 @@ async def get_all(*, db: AsyncSession) -> Sequence[AiProvider]: return ai_providers @staticmethod - async def create(*, db: AsyncSession, obj: CreateAiProviderParam) -> None: + async def create(*, db: AsyncSession, obj: CreateAIProviderParam) -> None: """ - 创建供应商 + 创建 AI 供应商 :param db: 数据库会话 :param obj: 创建供应商参数 :return: """ + if obj.api_host.endswith('/'): + raise errors.RequestError(msg='API 请求地址不能以 `/` 结尾') await ai_provider_dao.create(db, obj) @staticmethod - async def update(*, db: AsyncSession, pk: int, obj: UpdateAiProviderParam) -> int: + async def update(*, db: AsyncSession, pk: int, obj: UpdateAIProviderParam) -> int: """ - 更新供应商 + 更新 AI 供应商 :param db: 数据库会话 :param pk: 供应商 ID @@ -72,9 +127,9 @@ async def update(*, db: AsyncSession, pk: int, obj: UpdateAiProviderParam) -> in return count @staticmethod - async def delete(*, db: AsyncSession, obj: DeleteAiProviderParam) -> int: + async def delete(*, db: AsyncSession, obj: DeleteAIProviderParam) -> int: """ - 删除供应商 + 删除 AI 供应商 :param db: 数据库会话 :param obj: 供应商 ID 列表 @@ -84,4 +139,4 @@ async def delete(*, db: AsyncSession, obj: DeleteAiProviderParam) -> int: return count -ai_provider_service: AiProviderService = AiProviderService() +ai_provider_service: AIProviderService = AIProviderService() diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/message_parse.py b/utils/message_parse.py new file mode 100644 index 0000000..397bf22 --- /dev/null +++ b/utils/message_parse.py @@ -0,0 +1,32 @@ +from typing import Literal, TypedDict + +from pydantic_ai import ModelMessage, ModelRequest, ModelResponse, TextPart, UserPromptPart + +from backend.common.exception import errors + + +class ChatMessage(TypedDict): + """发送给浏览器的消息格式""" + + role: Literal['user', 'model'] + timestamp: str + content: str + + +def to_chat_message(message: ModelMessage) -> ChatMessage: + first_part = message.parts[0] + if isinstance(message, ModelRequest): + if isinstance(first_part, UserPromptPart): + assert isinstance(first_part.content, str) + return { + 'role': 'user', + 'timestamp': first_part.timestamp.isoformat(), + 'content': first_part.content, + } + elif isinstance(message, ModelResponse) and isinstance(first_part, TextPart): + return { + 'role': 'model', + 'timestamp': message.timestamp.isoformat(), + 'content': first_part.content, + } + raise errors.NotFoundError(msg=f'消息类型错误: {message}') diff --git a/utils/model_control.py b/utils/model_control.py new file mode 100644 index 0000000..b6077de --- /dev/null +++ b/utils/model_control.py @@ -0,0 +1,127 @@ +from openai import AsyncOpenAI +from pydantic_ai import ModelSettings +from pydantic_ai.models.anthropic import AnthropicModel +from pydantic_ai.models.bedrock import BedrockConverseModel +from pydantic_ai.models.cerebras import CerebrasModel +from pydantic_ai.models.cohere import CohereModel +from pydantic_ai.models.google import GoogleModel +from pydantic_ai.models.groq import GroqModel +from pydantic_ai.models.huggingface import HuggingFaceModel +from pydantic_ai.models.mistral import MistralModel +from pydantic_ai.models.openai import OpenAIChatModel +from pydantic_ai.models.openrouter import OpenRouterModel +from pydantic_ai.models.outlines import OutlinesModel +from pydantic_ai.providers.anthropic import AnthropicProvider +from pydantic_ai.providers.bedrock import BedrockProvider +from pydantic_ai.providers.cerebras import CerebrasProvider +from pydantic_ai.providers.cohere import CohereProvider +from pydantic_ai.providers.google import GoogleProvider +from pydantic_ai.providers.groq import GroqProvider +from pydantic_ai.providers.huggingface import HuggingFaceProvider +from pydantic_ai.providers.mistral import MistralProvider +from pydantic_ai.providers.openai import OpenAIProvider +from pydantic_ai.providers.openrouter import OpenRouterProvider +from pydantic_ai.providers.outlines import OutlinesProvider + +from backend.common.exception import errors +from backend.plugin.ai.enums import AIProviderType + + +def get_pydantic_model( # noqa: C901 + provider_type: int, model_name: str, api_key: str, base_url: str, model_settings: ModelSettings +) -> ( + OpenAIChatModel + | AnthropicModel + | GoogleModel + | BedrockConverseModel + | CerebrasModel + | CohereModel + | GroqModel + | HuggingFaceModel + | MistralModel + | OpenRouterModel + | OutlinesModel +): + """ + 获取 pydantic 模型 + + :param provider_type: 供应商类型 + :param model_name: 模型名称 + :param api_key: 密钥 + :param base_url: API 基础域名 + :param model_settings: 模型配置 + :return: + """ + base_url = base_url.rstrip('/') if base_url else None + + if provider_type == AIProviderType.openai: + openai_base_url = None + if base_url: + openai_base_url = f'{base_url}/v1' if not base_url.endswith('/v1') else base_url + return OpenAIChatModel( + model_name, provider=OpenAIProvider(base_url=openai_base_url, api_key=api_key), settings=model_settings + ) + + if provider_type == AIProviderType.anthropic: + anthropic_base_url = None + if base_url: + anthropic_base_url = f'{base_url}/v1' if not base_url.endswith('/v1') else base_url + return AnthropicModel( + model_name, + provider=AnthropicProvider(base_url=anthropic_base_url, api_key=api_key), + settings=model_settings, + ) + + if provider_type == AIProviderType.gemini: + google_base_url = None + if base_url: + google_base_url = f'{base_url}/v1beta/openai' if not base_url.endswith('/v1beta/openai') else base_url + return GoogleModel( + model_name, provider=GoogleProvider(base_url=google_base_url, api_key=api_key), settings=model_settings + ) + + if provider_type == AIProviderType.bedrock: + region = base_url if base_url and not base_url.startswith('http') else 'us-east-1' + return BedrockConverseModel( + model_name, # type: ignore[arg-type] + provider=BedrockProvider(api_key=api_key, region_name=region), + settings=model_settings, + ) + + if provider_type == AIProviderType.cerebras: + return CerebrasModel(model_name, provider=CerebrasProvider(api_key=api_key), settings=model_settings) # type: ignore[arg-type] + + if provider_type == AIProviderType.cohere: + return CohereModel(model_name, provider=CohereProvider(api_key=api_key), settings=model_settings) + + if provider_type == AIProviderType.groq: + groq_base_url = None + if base_url: + groq_base_url = f'{base_url}/openai/v1' if not base_url.endswith('/openai/v1') else base_url + return GroqModel( + model_name, provider=GroqProvider(api_key=api_key, base_url=groq_base_url), settings=model_settings + ) + + if provider_type == AIProviderType.hugging_face: + return HuggingFaceModel( + model_name, provider=HuggingFaceProvider(base_url=base_url, api_key=api_key), settings=model_settings + ) + + if provider_type == AIProviderType.mistral: + return MistralModel(model_name, provider=MistralProvider(api_key=api_key), settings=model_settings) + + if provider_type == AIProviderType.openrouter: + openrouter_base_url = None + if base_url: + openrouter_base_url = f'{base_url}/api/v1' if not base_url.endswith('/api/v1') else base_url + openai_client = AsyncOpenAI(base_url=openrouter_base_url, api_key=api_key) if openrouter_base_url else None + return OpenRouterModel( + model_name, + provider=OpenRouterProvider(api_key=api_key, openai_client=openai_client), + settings=model_settings, + ) + + if provider_type == AIProviderType.outlines: + return OutlinesModel(model_name, provider=OutlinesProvider(), settings=model_settings) # type: ignore[arg-type] + + raise errors.NotFoundError(msg=f'不支持的供应商类型: {provider_type}')