diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/reranker.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/reranker.py index 61fe4075db2..9b30aaf14c0 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/reranker.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/reranker.py @@ -7,12 +7,22 @@ from common import forms from common.exception.app_exception import AppApiException -from common.forms import BaseForm, PasswordInputField +from common.forms import BaseForm, PasswordInputField, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode from models_provider.impl.aliyun_bai_lian_model_provider.model.reranker import AliyunBaiLianReranker from common.utils.logger import maxkb_logger +class AliyunRerankerModelParams(BaseForm): + top_n = forms.SliderField(TooltipLabel(_('Top N'), + _('Number of top documents to return after reranking')), + required=True, default_value=3, + _min=1, + _max=100, + _step=1, + precision=0) + + class AliyunBaiLianRerankerCredential(BaseForm, BaseModelCredential): """ Credential class for the Aliyun BaiLian Reranker model. @@ -86,3 +96,6 @@ def encryption_dict(self, model: Dict[str, object]) -> Dict[str, object]: **model, 'dashscope_api_key': super().encryption(model.get('dashscope_api_key', '')) } + + def get_model_params_setting_form(self, model_name: str) -> AliyunRerankerModelParams: + return AliyunRerankerModelParams() diff --git a/apps/models_provider/impl/docker_ai_model_provider/credential/reranker.py b/apps/models_provider/impl/docker_ai_model_provider/credential/reranker.py index f89ffa2639b..98ebbd252db 100644 --- a/apps/models_provider/impl/docker_ai_model_provider/credential/reranker.py +++ b/apps/models_provider/impl/docker_ai_model_provider/credential/reranker.py @@ -13,11 +13,22 @@ from common import forms from common.exception.app_exception import AppApiException -from common.forms import BaseForm +from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode from models_provider.impl.docker_ai_model_provider.model.reranker import DockerAIReranker from common.utils.logger import maxkb_logger + +class DockerAIRerankerModelParams(BaseForm): + top_n = forms.SliderField(TooltipLabel(_('Top N'), + _('Number of top documents to return after reranking')), + required=True, default_value=3, + _min=1, + _max=100, + _step=1, + precision=0) + + class DockerAIRerankerCredential(BaseForm, BaseModelCredential): def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, @@ -48,4 +59,8 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje def encryption_dict(self, model: Dict[str, object]): return {**model} + api_base = forms.TextInputField('API URL', required=True) + + def get_model_params_setting_form(self, model_name: str) -> DockerAIRerankerModelParams: + return DockerAIRerankerModelParams() diff --git a/apps/models_provider/impl/local_model_provider/credential/reranker/model.py b/apps/models_provider/impl/local_model_provider/credential/reranker/model.py index 3c6fa4e327a..9d381747ba6 100644 --- a/apps/models_provider/impl/local_model_provider/credential/reranker/model.py +++ b/apps/models_provider/impl/local_model_provider/credential/reranker/model.py @@ -12,12 +12,23 @@ from common import forms from common.exception.app_exception import AppApiException -from common.forms import BaseForm +from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode from models_provider.impl.local_model_provider.model.reranker import LocalReranker from django.utils.translation import gettext_lazy as _, gettext from common.utils.logger import maxkb_logger + +class LocalRerankerModelParams(BaseForm): + top_n = forms.SliderField(TooltipLabel(_('Top N'), + _('Number of top documents to return after reranking')), + required=True, default_value=3, + _min=1, + _max=100, + _step=1, + precision=0) + + class LocalRerankerCredential(BaseForm, BaseModelCredential): def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, @@ -51,3 +62,6 @@ def encryption_dict(self, model: Dict[str, object]): return model cache_dir = forms.TextInputField(_('Model catalog'), required=True) + + def get_model_params_setting_form(self, model_name: str) -> LocalRerankerModelParams: + return LocalRerankerModelParams() diff --git a/apps/models_provider/impl/ollama_model_provider/credential/reranker.py b/apps/models_provider/impl/ollama_model_provider/credential/reranker.py index d82f3cfee51..0fb0bd1e37c 100644 --- a/apps/models_provider/impl/ollama_model_provider/credential/reranker.py +++ b/apps/models_provider/impl/ollama_model_provider/credential/reranker.py @@ -9,14 +9,24 @@ from typing import Dict from django.utils.translation import gettext as _ +from django.utils.translation import gettext_lazy as _, gettext +from langchain_core.documents import Document from common import forms from common.exception.app_exception import AppApiException -from common.forms import BaseForm +from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode from models_provider.impl.ollama_model_provider.model.reranker import OllamaReranker -from langchain_core.documents import BaseDocumentCompressor, Document -from django.utils.translation import gettext_lazy as _, gettext + + +class OllamaRerankerModelParams(BaseForm): + top_n = forms.SliderField(TooltipLabel(_('Top N'), + _('Number of top documents to return after reranking')), + required=True, default_value=3, + _min=1, + _max=100, + _step=1, + precision=0) class OllamaReRankModelCredential(BaseForm, BaseModelCredential): @@ -64,3 +74,6 @@ def build_model(self, model_info: Dict[str, object]): return self api_base = forms.TextInputField('API URL', required=True) + + def get_model_params_setting_form(self, model_name: str) -> OllamaRerankerModelParams: + return OllamaRerankerModelParams() diff --git a/apps/models_provider/impl/siliconCloud_model_provider/credential/reranker.py b/apps/models_provider/impl/siliconCloud_model_provider/credential/reranker.py index 7a0b17ba6f0..e49de9f583e 100644 --- a/apps/models_provider/impl/siliconCloud_model_provider/credential/reranker.py +++ b/apps/models_provider/impl/siliconCloud_model_provider/credential/reranker.py @@ -13,11 +13,22 @@ from common import forms from common.exception.app_exception import AppApiException -from common.forms import BaseForm +from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode from models_provider.impl.siliconCloud_model_provider.model.reranker import SiliconCloudReranker from common.utils.logger import maxkb_logger + +class SiliconCloudRerankerModelParams(BaseForm): + top_n = forms.SliderField(TooltipLabel(_('Top N'), + _('Number of top documents to return after reranking')), + required=True, default_value=3, + _min=1, + _max=100, + _step=1, + precision=0) + + class SiliconCloudRerankerCredential(BaseForm, BaseModelCredential): def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, @@ -48,5 +59,9 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje def encryption_dict(self, model: Dict[str, object]): return {**model, 'api_key': super().encryption(model.get('api_key', ''))} + api_base = forms.TextInputField('API URL', required=True) api_key = forms.PasswordInputField('API Key', required=True) + + def get_model_params_setting_form(self, model_name: str) -> SiliconCloudRerankerModelParams: + return SiliconCloudRerankerModelParams() diff --git a/apps/models_provider/impl/vllm_model_provider/credential/reranker.py b/apps/models_provider/impl/vllm_model_provider/credential/reranker.py index 881c85179c9..d092704a71a 100644 --- a/apps/models_provider/impl/vllm_model_provider/credential/reranker.py +++ b/apps/models_provider/impl/vllm_model_provider/credential/reranker.py @@ -4,13 +4,24 @@ from common import forms from common.exception.app_exception import AppApiException -from common.forms import BaseForm +from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode from django.utils.translation import gettext_lazy as _ from models_provider.impl.vllm_model_provider.model.reranker import VllmBgeReranker from common.utils.logger import maxkb_logger + +class VllmRerankerModelParams(BaseForm): + top_n = forms.SliderField(TooltipLabel(_('Top N'), + _('Number of top documents to return after reranking')), + required=True, default_value=3, + _min=1, + _max=100, + _step=1, + precision=0) + + class VllmRerankerCredential(BaseForm, BaseModelCredential): api_url = forms.TextInputField('API URL', required=True) api_key = forms.PasswordInputField('API Key', required=True) @@ -47,4 +58,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje return True def encryption_dict(self, model_info: Dict[str, object]): - return {**model_info, 'api_key': super().encryption(model_info.get('api_key', ''))} \ No newline at end of file + return {**model_info, 'api_key': super().encryption(model_info.get('api_key', ''))} + + def get_model_params_setting_form(self, model_name: str) -> VllmRerankerModelParams: + return VllmRerankerModelParams() diff --git a/apps/models_provider/impl/wenxin_model_provider/credential/reranker.py b/apps/models_provider/impl/wenxin_model_provider/credential/reranker.py index 140101522d0..874644bb167 100644 --- a/apps/models_provider/impl/wenxin_model_provider/credential/reranker.py +++ b/apps/models_provider/impl/wenxin_model_provider/credential/reranker.py @@ -4,13 +4,23 @@ from common import forms from common.exception.app_exception import AppApiException -from common.forms import BaseForm +from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode from django.utils.translation import gettext_lazy as _ from common.utils.logger import maxkb_logger from models_provider.impl.wenxin_model_provider.model.reranker import QfBgeReranker +class QfRerankerModelParams(BaseForm): + top_n = forms.SliderField(TooltipLabel(_('Top N'), + _('Number of top documents to return after reranking')), + required=True, default_value=3, + _min=1, + _max=100, + _step=1, + precision=0) + + class QfRerankerCredential(BaseForm, BaseModelCredential): api_url = forms.TextInputField('API URL', required=True) api_key = forms.PasswordInputField('API Key', required=True) @@ -47,4 +57,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje return True def encryption_dict(self, model_info: Dict[str, object]): - return {**model_info, 'api_key': super().encryption(model_info.get('api_key', ''))} \ No newline at end of file + return {**model_info, 'api_key': super().encryption(model_info.get('api_key', ''))} + + def get_model_params_setting_form(self, model_name: str) -> QfRerankerModelParams: + return QfRerankerModelParams() diff --git a/apps/models_provider/impl/xinference_model_provider/credential/reranker.py b/apps/models_provider/impl/xinference_model_provider/credential/reranker.py index af1026cb102..94291f10549 100644 --- a/apps/models_provider/impl/xinference_model_provider/credential/reranker.py +++ b/apps/models_provider/impl/xinference_model_provider/credential/reranker.py @@ -13,10 +13,20 @@ from common import forms from common.exception.app_exception import AppApiException -from common.forms import BaseForm +from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode +class XInferenceRerankerModelParams(BaseForm): + top_n = forms.SliderField(TooltipLabel(_('Top N'), + _('Number of top documents to return after reranking')), + required=True, default_value=3, + _min=1, + _max=100, + _step=1, + precision=0) + + class XInferenceRerankerModelCredential(BaseForm, BaseModelCredential): def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=True): @@ -49,3 +59,6 @@ def encryption_dict(self, model_info: Dict[str, object]): server_url = forms.TextInputField('API URL', required=True) api_key = forms.PasswordInputField('API Key', required=False) + + def get_model_params_setting_form(self, model_name: str) -> XInferenceRerankerModelParams: + return XInferenceRerankerModelParams()