Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,22 @@

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, PasswordInputField
from common.forms import BaseForm, PasswordInputField, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from models_provider.impl.aliyun_bai_lian_model_provider.model.reranker import AliyunBaiLianReranker
from common.utils.logger import maxkb_logger


class AliyunRerankerModelParams(BaseForm):
top_n = forms.SliderField(TooltipLabel(_('Top N'),
_('Number of top documents to return after reranking')),
required=True, default_value=3,
_min=1,
_max=100,
_step=1,
precision=0)


class AliyunBaiLianRerankerCredential(BaseForm, BaseModelCredential):
"""
Credential class for the Aliyun BaiLian Reranker model.
Expand Down Expand Up @@ -86,3 +96,6 @@ def encryption_dict(self, model: Dict[str, object]) -> Dict[str, object]:
**model,
'dashscope_api_key': super().encryption(model.get('dashscope_api_key', ''))
}

def get_model_params_setting_form(self, model_name: str) -> AliyunRerankerModelParams:
return AliyunRerankerModelParams()
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,22 @@

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from models_provider.impl.docker_ai_model_provider.model.reranker import DockerAIReranker
from common.utils.logger import maxkb_logger


class DockerAIRerankerModelParams(BaseForm):
top_n = forms.SliderField(TooltipLabel(_('Top N'),
_('Number of top documents to return after reranking')),
required=True, default_value=3,
_min=1,
_max=100,
_step=1,
precision=0)


class DockerAIRerankerCredential(BaseForm, BaseModelCredential):

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
Expand Down Expand Up @@ -48,4 +59,8 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje

def encryption_dict(self, model: Dict[str, object]):
return {**model}

api_base = forms.TextInputField('API URL', required=True)

def get_model_params_setting_form(self, model_name: str) -> DockerAIRerankerModelParams:
return DockerAIRerankerModelParams()
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,23 @@

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from models_provider.impl.local_model_provider.model.reranker import LocalReranker
from django.utils.translation import gettext_lazy as _, gettext
from common.utils.logger import maxkb_logger


class LocalRerankerModelParams(BaseForm):
top_n = forms.SliderField(TooltipLabel(_('Top N'),
_('Number of top documents to return after reranking')),
required=True, default_value=3,
_min=1,
_max=100,
_step=1,
precision=0)


class LocalRerankerCredential(BaseForm, BaseModelCredential):

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
Expand Down Expand Up @@ -51,3 +62,6 @@ def encryption_dict(self, model: Dict[str, object]):
return model

cache_dir = forms.TextInputField(_('Model catalog'), required=True)

def get_model_params_setting_form(self, model_name: str) -> LocalRerankerModelParams:
return LocalRerankerModelParams()
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,24 @@
from typing import Dict

from django.utils.translation import gettext as _
from django.utils.translation import gettext_lazy as _, gettext
from langchain_core.documents import Document

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from models_provider.impl.ollama_model_provider.model.reranker import OllamaReranker
from langchain_core.documents import BaseDocumentCompressor, Document
from django.utils.translation import gettext_lazy as _, gettext


class OllamaRerankerModelParams(BaseForm):
top_n = forms.SliderField(TooltipLabel(_('Top N'),
_('Number of top documents to return after reranking')),
required=True, default_value=3,
_min=1,
_max=100,
_step=1,
precision=0)


class OllamaReRankModelCredential(BaseForm, BaseModelCredential):
Expand Down Expand Up @@ -64,3 +74,6 @@ def build_model(self, model_info: Dict[str, object]):
return self

api_base = forms.TextInputField('API URL', required=True)

def get_model_params_setting_form(self, model_name: str) -> OllamaRerankerModelParams:
return OllamaRerankerModelParams()
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,22 @@

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from models_provider.impl.siliconCloud_model_provider.model.reranker import SiliconCloudReranker
from common.utils.logger import maxkb_logger


class SiliconCloudRerankerModelParams(BaseForm):
top_n = forms.SliderField(TooltipLabel(_('Top N'),
_('Number of top documents to return after reranking')),
required=True, default_value=3,
_min=1,
_max=100,
_step=1,
precision=0)


class SiliconCloudRerankerCredential(BaseForm, BaseModelCredential):

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
Expand Down Expand Up @@ -48,5 +59,9 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje

def encryption_dict(self, model: Dict[str, object]):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

api_base = forms.TextInputField('API URL', required=True)
api_key = forms.PasswordInputField('API Key', required=True)

def get_model_params_setting_form(self, model_name: str) -> SiliconCloudRerankerModelParams:
return SiliconCloudRerankerModelParams()
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,24 @@

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from django.utils.translation import gettext_lazy as _

from models_provider.impl.vllm_model_provider.model.reranker import VllmBgeReranker
from common.utils.logger import maxkb_logger


class VllmRerankerModelParams(BaseForm):
top_n = forms.SliderField(TooltipLabel(_('Top N'),
_('Number of top documents to return after reranking')),
required=True, default_value=3,
_min=1,
_max=100,
_step=1,
precision=0)


class VllmRerankerCredential(BaseForm, BaseModelCredential):
api_url = forms.TextInputField('API URL', required=True)
api_key = forms.PasswordInputField('API Key', required=True)
Expand Down Expand Up @@ -47,4 +58,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
return True

def encryption_dict(self, model_info: Dict[str, object]):
return {**model_info, 'api_key': super().encryption(model_info.get('api_key', ''))}
return {**model_info, 'api_key': super().encryption(model_info.get('api_key', ''))}

def get_model_params_setting_form(self, model_name: str) -> VllmRerankerModelParams:
return VllmRerankerModelParams()
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,23 @@

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from django.utils.translation import gettext_lazy as _
from common.utils.logger import maxkb_logger
from models_provider.impl.wenxin_model_provider.model.reranker import QfBgeReranker


class QfRerankerModelParams(BaseForm):
top_n = forms.SliderField(TooltipLabel(_('Top N'),
_('Number of top documents to return after reranking')),
required=True, default_value=3,
_min=1,
_max=100,
_step=1,
precision=0)


class QfRerankerCredential(BaseForm, BaseModelCredential):
api_url = forms.TextInputField('API URL', required=True)
api_key = forms.PasswordInputField('API Key', required=True)
Expand Down Expand Up @@ -47,4 +57,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
return True

def encryption_dict(self, model_info: Dict[str, object]):
return {**model_info, 'api_key': super().encryption(model_info.get('api_key', ''))}
return {**model_info, 'api_key': super().encryption(model_info.get('api_key', ''))}

def get_model_params_setting_form(self, model_name: str) -> QfRerankerModelParams:
return QfRerankerModelParams()
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,20 @@

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode


class XInferenceRerankerModelParams(BaseForm):
top_n = forms.SliderField(TooltipLabel(_('Top N'),
_('Number of top documents to return after reranking')),
required=True, default_value=3,
_min=1,
_max=100,
_step=1,
precision=0)


class XInferenceRerankerModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=True):
Expand Down Expand Up @@ -49,3 +59,6 @@ def encryption_dict(self, model_info: Dict[str, object]):
server_url = forms.TextInputField('API URL', required=True)

api_key = forms.PasswordInputField('API Key', required=False)

def get_model_params_setting_form(self, model_name: str) -> XInferenceRerankerModelParams:
return XInferenceRerankerModelParams()
Loading