Skip to content

工具

lazyllm.tools.IntentClassifier

Bases: ModuleBase

意图分类模块,用于根据输入文本在给定的意图列表中进行分类。
支持中英文自动选择提示模板,并可通过示例、提示、约束和注意事项增强分类效果。

Parameters:

  • llm

    用于意图分类的大语言模型实例。

  • intent_list (list, default: None ) –

    可选,意图类别列表,例如 ["聊天", "天气", "问答"]。

  • prompt (str, default: '' ) –

    可选,自定义提示语,插入到系统提示模板中。

  • constrain (str, default: '' ) –

    可选,分类约束条件说明。

  • attention (str, default: '' ) –

    可选,提示注意事项。

  • examples (list[list[str, str]], default: None ) –

    可选,分类示例列表,每个元素为 [输入文本, 标签]。

  • return_trace (bool, default: False ) –

    是否返回执行过程的 trace,默认为 False。

Examples:

>>> import lazyllm
>>> from lazyllm.tools import IntentClassifier
>>> classifier_llm = lazyllm.OnlineChatModule(source="openai")
>>> chatflow_intent_list = ["Chat", "Financial Knowledge Q&A", "Employee Information Query", "Weather Query"]
>>> classifier = IntentClassifier(classifier_llm, intent_list=chatflow_intent_list)
>>> classifier.start()
>>> print(classifier('What is the weather today'))
Weather Query
>>>
>>> with IntentClassifier(classifier_llm) as ic:
>>>     ic.case['Weather Query', lambda x: '38.5°C']
>>>     ic.case['Chat', lambda x: 'permission denied']
>>>     ic.case['Financial Knowledge Q&A', lambda x: 'Calling Financial RAG']
>>>     ic.case['Employee Information Query', lambda x: 'Beijing']
...
>>> ic.start()
>>> print(ic('What is the weather today'))
38.5°C
Source code in lazyllm/tools/classifier/intent_classifier.py
class IntentClassifier(ModuleBase):
    """意图分类模块,用于根据输入文本在给定的意图列表中进行分类。  
支持中英文自动选择提示模板,并可通过示例、提示、约束和注意事项增强分类效果。

Args:
    llm: 用于意图分类的大语言模型实例。
    intent_list (list): 可选,意图类别列表,例如 ["聊天", "天气", "问答"]。
    prompt (str): 可选,自定义提示语,插入到系统提示模板中。
    constrain (str): 可选,分类约束条件说明。
    attention (str): 可选,提示注意事项。
    examples (list[list[str, str]]): 可选,分类示例列表,每个元素为 [输入文本, 标签]。
    return_trace (bool): 是否返回执行过程的 trace,默认为 False。


Examples:
        >>> import lazyllm
        >>> from lazyllm.tools import IntentClassifier
        >>> classifier_llm = lazyllm.OnlineChatModule(source="openai")
        >>> chatflow_intent_list = ["Chat", "Financial Knowledge Q&A", "Employee Information Query", "Weather Query"]
        >>> classifier = IntentClassifier(classifier_llm, intent_list=chatflow_intent_list)
        >>> classifier.start()
        >>> print(classifier('What is the weather today'))
        Weather Query
        >>>
        >>> with IntentClassifier(classifier_llm) as ic:
        >>>     ic.case['Weather Query', lambda x: '38.5°C']
        >>>     ic.case['Chat', lambda x: 'permission denied']
        >>>     ic.case['Financial Knowledge Q&A', lambda x: 'Calling Financial RAG']
        >>>     ic.case['Employee Information Query', lambda x: 'Beijing']
        ...
        >>> ic.start()
        >>> print(ic('What is the weather today'))
        38.5°C
    """
    def __init__(self, llm, intent_list: list = None,
                 *, prompt: str = '', constrain: str = '', attention: str = '',
                 examples: Optional[list[list[str, str]]] = None, return_trace: bool = False) -> None:
        super().__init__(return_trace=return_trace)
        self._intent_list = intent_list or []
        self._llm = llm
        self._prompt, self._constrain, self._attention, self._examples = prompt, constrain, attention, examples or []
        if self._intent_list:
            self._init()

    def _init(self):
        def choose_prompt():
            # Use chinese prompt if intent elements have chinese character, otherwise use english version
            for ele in self._intent_list:
                for ch in ele:
                    # chinese unicode range
                    if "\u4e00" <= ch <= "\u9fff":
                        return ch_prompt_classifier_template
            return en_prompt_classifier_template

        example_template = '\nUser: {{{{"human_input": "{inp}", "intent_list": {intent}}}}}\nAssistant: {label}\n'
        examples = ''.join([example_template.format(
            inp=input, intent=self._intent_list, label=label) for input, label in self._examples])
        prompt = choose_prompt().replace(
            '{user_prompt}', f' {self._prompt}').replace('{attention}', self._attention).replace(
            '{user_constrains}', f' {self._constrain}').replace('{user_examples}', f' {examples}')
        self._llm = self._llm.share(prompt=AlpacaPrompter(dict(system=prompt, user='${input}')
                                                          ).pre_hook(self.intent_promt_hook)).used_by(self._module_id)
        self._impl = pipeline(self._llm, self.post_process_result)

    def intent_promt_hook(
        self,
        input: Union[str, List, Dict[str, str], None] = None,
        history: List[Union[List[str], Dict[str, Any]]] = [],  # noqa B006
        tools: Union[List[Dict[str, Any]], None] = None,
        label: Union[str, None] = None,
    ):
        """意图分类的预处理 Hook。  
将输入文本与意图列表打包为 JSON,并生成历史对话信息字符串。

Args:
    input (str | List | Dict | None): 输入文本,仅支持字符串类型。
    history (List): 历史对话记录,默认为空列表。
    tools (List[Dict] | None): 工具信息,可选。
    label (str | None): 标签,可选。

**Returns**

- tuple: (输入数据字典, 历史记录列表, 工具信息, 标签)
"""
        input_json = {}
        if isinstance(input, str):
            input_json = {"human_input": input, "intent_list": self._intent_list}
        else:
            raise ValueError(f"Unexpected type for input: {type(input)}")

        history_info = chat_history_to_str(history)
        history = []
        input_text = json.dumps(input_json, ensure_ascii=False)
        return dict(history_info=history_info, input=input_text), history, tools, label

    def post_process_result(self, input):
        """意图分类结果的后处理。  
如果结果在意图列表中则直接返回,否则返回意图列表的第一个元素。

Args:
    input (str): 分类模型输出结果。

**Returns**

- str: 最终的分类标签。
"""
        input = input.strip()
        return input if input in self._intent_list else self._intent_list[0]

    def forward(self, input: str, llm_chat_history: List[Dict[str, Any]] = None):
        if llm_chat_history is not None and self._llm._module_id not in globals["chat_history"]:
            globals["chat_history"][self._llm._module_id] = llm_chat_history
        return self._impl(input)

    def __enter__(self):
        assert not self._intent_list, 'Intent list is already set'
        self._sw = switch()
        self._sw.__enter__()
        return self

    @property
    def case(self):
        return switch.Case(self)

    @property
    def submodules(self):
        submodule = []
        if isinstance(self._impl, switch):
            self._impl.for_each(lambda x: isinstance(x, ModuleBase), lambda x: submodule.append(x))
        return super().submodules + submodule

    # used by switch.Case
    def _add_case(self, cond, func):
        assert isinstance(cond, str), 'intent must be string'
        self._intent_list.append(cond)
        self._sw.case[cond, func]

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._sw.__exit__(exc_type, exc_val, exc_tb)
        self._init()
        self._sw._set_conversion(self._impl)
        self._impl = self._sw

intent_promt_hook(input=None, history=[], tools=None, label=None)

意图分类的预处理 Hook。
将输入文本与意图列表打包为 JSON,并生成历史对话信息字符串。

Parameters:

  • input (str | List | Dict | None, default: None ) –

    输入文本,仅支持字符串类型。

  • history (List, default: [] ) –

    历史对话记录,默认为空列表。

  • tools (List[Dict] | None, default: None ) –

    工具信息,可选。

  • label (str | None, default: None ) –

    标签,可选。

Returns

  • tuple: (输入数据字典, 历史记录列表, 工具信息, 标签)
Source code in lazyllm/tools/classifier/intent_classifier.py
    def intent_promt_hook(
        self,
        input: Union[str, List, Dict[str, str], None] = None,
        history: List[Union[List[str], Dict[str, Any]]] = [],  # noqa B006
        tools: Union[List[Dict[str, Any]], None] = None,
        label: Union[str, None] = None,
    ):
        """意图分类的预处理 Hook。  
将输入文本与意图列表打包为 JSON,并生成历史对话信息字符串。

Args:
    input (str | List | Dict | None): 输入文本,仅支持字符串类型。
    history (List): 历史对话记录,默认为空列表。
    tools (List[Dict] | None): 工具信息,可选。
    label (str | None): 标签,可选。

**Returns**

- tuple: (输入数据字典, 历史记录列表, 工具信息, 标签)
"""
        input_json = {}
        if isinstance(input, str):
            input_json = {"human_input": input, "intent_list": self._intent_list}
        else:
            raise ValueError(f"Unexpected type for input: {type(input)}")

        history_info = chat_history_to_str(history)
        history = []
        input_text = json.dumps(input_json, ensure_ascii=False)
        return dict(history_info=history_info, input=input_text), history, tools, label

post_process_result(input)

意图分类结果的后处理。
如果结果在意图列表中则直接返回,否则返回意图列表的第一个元素。

Parameters:

  • input (str) –

    分类模型输出结果。

Returns

  • str: 最终的分类标签。
Source code in lazyllm/tools/classifier/intent_classifier.py
    def post_process_result(self, input):
        """意图分类结果的后处理。  
如果结果在意图列表中则直接返回,否则返回意图列表的第一个元素。

Args:
    input (str): 分类模型输出结果。

**Returns**

- str: 最终的分类标签。
"""
        input = input.strip()
        return input if input in self._intent_list else self._intent_list[0]

lazyllm.tools.Document

Bases: ModuleBase, BuiltinGroups

初始化一个具有可选用户界面的文档模块。

此构造函数初始化一个可以有或没有用户界面的文档模块。如果启用了用户界面,它还会提供一个ui界面来管理文档操作接口,并提供一个用于用户界面交互的网页。

Parameters:

  • dataset_path (str, default: None ) –

    数据集目录的路径。此目录应包含要由文档模块管理的文档。

  • embed (Optional[Union[Callable, Dict[str, Callable]]], default: None ) –

    用于生成文档 embedding 的对象。如果需要对文本生成多个 embedding,此处需要通过字典的方式指定多个 embedding 模型,key 标识 embedding 对应的名字, value 为对应的 embedding 模型。

  • create_ui (bool, default: False ) –

    [已弃用] 是否创建用户界面。请改用'manager'参数

  • manager (bool, default: False ) –

    指示是否为文档模块创建用户界面的标志。默认为 False。

  • server (Union[bool, int], default: False ) –

    服务器配置。True表示默认服务器,False表示已指定端口号作为自定义服务器

  • name (Optional[str], default: None ) –

    文档集合的名称标识符。云服务模式下必须提供

  • launcher (optional, default: None ) –

    负责启动服务器模块的对象或函数。如果未提供,则使用 lazyllm.launchers 中的默认异步启动器 (sync=False)。

  • doc_files (Optional[List[str]], default: None ) –

    临时文档文件列表(dataset_path的替代方案)。使用时dataset_path必须为None且仅支持map存储类型

  • store_conf (optional, default: None ) –

    配置使用哪种存储后端, 默认使用MapStore将切片数据存于内存中。

Examples:

>>> import lazyllm
>>> from lazyllm.tools import Document
>>> m = lazyllm.OnlineEmbeddingModule(source="glm")
>>> documents = Document(dataset_path='your_doc_path', embed=m, manager=False)  # or documents = Document(dataset_path='your_doc_path', embed={"key": m}, manager=False)
>>> m1 = lazyllm.TrainableModule("bge-large-zh-v1.5").start()
>>> document1 = Document(dataset_path='your_doc_path', embed={"online": m, "local": m1}, manager=False)
>>> store_conf = {
>>>     "segment_store": {
>>>         "type": "map",
>>>         "kwargs": {
>>>             "uri": "/tmp/tmp_segments.db",
>>>         },
>>>     },
>>>     "vector_store": {
>>>         "type": "milvus",
>>>         "kwargs": {
>>>             "uri": "/tmp/tmp_milvus.db",
>>>             "index_kwargs": {
>>>                 "index_type": "FLAT",
>>>                 "metric_type": "COSINE",
>>>             },
>>>         },
>>>     },
>>> }
>>> doc_fields = {
>>>     'author': DocField(data_type=DataType.VARCHAR, max_size=128, default_value=' '),
>>>     'public_year': DocField(data_type=DataType.INT32),
>>> }
>>> document2 = Document(dataset_path='your_doc_path', embed={"online": m, "local": m1}, store_conf=store_conf, doc_fields=doc_fields)
Source code in lazyllm/tools/rag/document.py
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
class Document(ModuleBase, BuiltinGroups, metaclass=_MetaDocument):
    """初始化一个具有可选用户界面的文档模块。

此构造函数初始化一个可以有或没有用户界面的文档模块。如果启用了用户界面,它还会提供一个ui界面来管理文档操作接口,并提供一个用于用户界面交互的网页。

Args:
    dataset_path (str): 数据集目录的路径。此目录应包含要由文档模块管理的文档。
    embed (Optional[Union[Callable, Dict[str, Callable]]]): 用于生成文档 embedding 的对象。如果需要对文本生成多个 embedding,此处需要通过字典的方式指定多个 embedding 模型,key 标识 embedding 对应的名字, value 为对应的 embedding 模型。
    create_ui (bool):[已弃用] 是否创建用户界面。请改用'manager'参数
    manager (bool, optional): 指示是否为文档模块创建用户界面的标志。默认为 False。
    server (Union[bool, int]):服务器配置。True表示默认服务器,False表示已指定端口号作为自定义服务器
    name (Optional[str]):文档集合的名称标识符。云服务模式下必须提供
    launcher (optional): 负责启动服务器模块的对象或函数。如果未提供,则使用 `lazyllm.launchers` 中的默认异步启动器 (`sync=False`)。            
    doc_files (Optional[List[str]]):临时文档文件列表(dataset_path的替代方案)。使用时dataset_path必须为None且仅支持map存储类型
    store_conf (optional): 配置使用哪种存储后端, 默认使用MapStore将切片数据存于内存中。


Examples:
    >>> import lazyllm
    >>> from lazyllm.tools import Document
    >>> m = lazyllm.OnlineEmbeddingModule(source="glm")
    >>> documents = Document(dataset_path='your_doc_path', embed=m, manager=False)  # or documents = Document(dataset_path='your_doc_path', embed={"key": m}, manager=False)
    >>> m1 = lazyllm.TrainableModule("bge-large-zh-v1.5").start()
    >>> document1 = Document(dataset_path='your_doc_path', embed={"online": m, "local": m1}, manager=False)

    >>> store_conf = {
    >>>     "segment_store": {
    >>>         "type": "map",
    >>>         "kwargs": {
    >>>             "uri": "/tmp/tmp_segments.db",
    >>>         },
    >>>     },
    >>>     "vector_store": {
    >>>         "type": "milvus",
    >>>         "kwargs": {
    >>>             "uri": "/tmp/tmp_milvus.db",
    >>>             "index_kwargs": {
    >>>                 "index_type": "FLAT",
    >>>                 "metric_type": "COSINE",
    >>>             },
    >>>         },
    >>>     },
    >>> }
    >>> doc_fields = {
    >>>     'author': DocField(data_type=DataType.VARCHAR, max_size=128, default_value=' '),
    >>>     'public_year': DocField(data_type=DataType.INT32),
    >>> }
    >>> document2 = Document(dataset_path='your_doc_path', embed={"online": m, "local": m1}, store_conf=store_conf, doc_fields=doc_fields)
    """
    class _Manager(ModuleBase):
        def __init__(self, dataset_path: Optional[str], embed: Optional[Union[Callable, Dict[str, Callable]]] = None,
                     manager: Union[bool, str] = False, server: Union[bool, int] = False, name: Optional[str] = None,
                     launcher: Optional[Launcher] = None, store_conf: Optional[Dict] = None,
                     doc_fields: Optional[Dict[str, DocField]] = None, cloud: bool = False,
                     doc_files: Optional[List[str]] = None, processor: Optional[DocumentProcessor] = None,
                     display_name: Optional[str] = "", description: Optional[str] = "algorithm description"):
            super().__init__()
            self._origin_path, self._doc_files, self._cloud = dataset_path, doc_files, cloud

            if dataset_path and not os.path.exists(dataset_path):
                defatult_path = os.path.join(lazyllm.config["data_path"], dataset_path)
                if os.path.exists(defatult_path):
                    dataset_path = defatult_path
            elif dataset_path:
                dataset_path = os.path.join(os.getcwd(), dataset_path)

            self._launcher: Launcher = launcher if launcher else lazyllm.launchers.remote(sync=False)
            self._dataset_path = dataset_path
            self._embed = self._get_embeds(embed)
            self._processor = processor
            name = name or DocListManager.DEFAULT_GROUP_NAME
            if not display_name: display_name = name

            self._dlm = None if (self._cloud or self._doc_files is not None) else DocListManager(
                dataset_path, name, enable_path_monitoring=False if manager else True)
            self._kbs = CallableDict({name: DocImpl(
                embed=self._embed, dlm=self._dlm, doc_files=doc_files, global_metadata_desc=doc_fields,
                store=store_conf, processor=processor, algo_name=name, display_name=display_name,
                description=description)})

            if manager: self._manager = ServerModule(DocManager(self._dlm), launcher=self._launcher)
            if manager == 'ui': self._docweb = DocWebModule(doc_server=self._manager)
            if server: self._kbs = ServerModule(self._kbs, port=(None if isinstance(server, bool) else int(server)))
            self._global_metadata_desc = doc_fields

        @property
        def url(self):
            if hasattr(self, '_manager'): return self._manager._url
            return None

        @property
        @deprecated('Document.manager.url')
        def _url(self):
            return self.url

        @property
        def web_url(self):
            if hasattr(self, '_docweb'): return self._docweb.url
            return None

        def _get_embeds(self, embed):
            embeds = embed if isinstance(embed, dict) else {EMBED_DEFAULT_KEY: embed} if embed else {}
            for embed in embeds.values():
                if isinstance(embed, ModuleBase):
                    self._submodules.append(embed)
            return embeds

        def add_kb_group(self, name, doc_fields: Optional[Dict[str, DocField]] = None,
                         store_conf: Optional[Dict] = None,
                         embed: Optional[Union[Callable, Dict[str, Callable]]] = None):
            embed = self._get_embeds(embed) if embed else self._embed
            if isinstance(self._kbs, ServerModule):
                self._kbs._impl._m[name] = DocImpl(dlm=self._dlm, embed=embed, kb_group_name=name,
                                                   global_metadata_desc=doc_fields, store=store_conf)
            else:
                self._kbs[name] = DocImpl(dlm=self._dlm, embed=self._embed, kb_group_name=name,
                                          global_metadata_desc=doc_fields, store=store_conf)
            self._dlm.add_kb_group(name=name)

        def get_doc_by_kb_group(self, name):
            return self._kbs._impl._m[name] if isinstance(self._kbs, ServerModule) else self._kbs[name]

        def stop(self):
            if hasattr(self, '_docweb'):
                self._docweb.stop()
            self._launcher.cleanup()

        def __call__(self, *args, **kw):
            return self._kbs(*args, **kw)

    def __new__(cls, *args, **kw):
        if url := kw.pop('url', None):
            name = kw.pop('name', None)
            assert not args and not kw, 'Only `name` is supported with `url`'
            return UrlDocument(url, name)
        else:
            return super().__new__(cls)

    def __init__(self, dataset_path: Optional[str] = None, embed: Optional[Union[Callable, Dict[str, Callable]]] = None,
                 create_ui: bool = False, manager: Union[bool, str, "Document._Manager", DocumentProcessor] = False,
                 server: Union[bool, int] = False, name: Optional[str] = None, launcher: Optional[Launcher] = None,
                 doc_files: Optional[List[str]] = None, doc_fields: Dict[str, DocField] = None,
                 store_conf: Optional[Dict] = None, display_name: Optional[str] = "",
                 description: Optional[str] = "algorithm description"):
        super().__init__()
        if create_ui:
            lazyllm.LOG.warning('`create_ui` for Document is deprecated, use `manager` instead')
            manager = create_ui
        if isinstance(dataset_path, (tuple, list)):
            doc_fields = dataset_path
            dataset_path = None
        if doc_files is not None:
            assert dataset_path is None and not manager, (
                'Manager and dataset_path are not supported for Document with temp-files')
            assert store_conf is None or store_conf['type'] == 'map', (
                'Only map store is supported for Document with temp-files')

        name = name or DocListManager.DEFAULT_GROUP_NAME

        if isinstance(manager, Document._Manager):
            assert not server, 'Server infomation is already set to by manager'
            assert not launcher, 'Launcher infomation is already set to by manager'
            assert not manager._cloud, 'manager is not allowed to share in cloud mode'
            assert manager._doc_files is None, 'manager is not allowed to share with temp files'
            if dataset_path != manager._dataset_path and dataset_path != manager._origin_path:
                raise RuntimeError(f'Document path mismatch, expected `{manager._dataset_path}`'
                                   f'while received `{dataset_path}`')
            manager.add_kb_group(name=name, doc_fields=doc_fields, store_conf=store_conf, embed=embed)
            self._manager = manager
            self._curr_group = name
        else:
            if isinstance(manager, DocumentProcessor):
                processor, cloud = manager, True
                processor._impl.start()
                manager = False
                assert name, '`Name` of Document is necessary when using cloud service'
                assert store_conf.get('type') != 'map', 'Cloud manager is not supported when using map store'
                assert not dataset_path, 'Cloud manager is not supported with local dataset path'
            else:
                cloud, processor = False, None
            self._manager = Document._Manager(dataset_path, embed, manager, server, name, launcher, store_conf,
                                              doc_fields, cloud=cloud, doc_files=doc_files, processor=processor,
                                              display_name=display_name, description=description)
            self._curr_group = name
        self._doc_to_db_processor: DocToDbProcessor = None

    def _list_all_files_in_dataset(self) -> List[str]:
        files_list = []
        for root, _, files in os.walk(self._manager._dataset_path):
            files = [os.path.join(root, file_path) for file_path in files]
            files_list.extend(files)
        return files_list

    def connect_sql_manager(
        self,
        sql_manager: SqlManager,
        schma: Optional[DocInfoSchema] = None,
        force_refresh: bool = True,
    ):
        def format_schema_to_dict(schema: DocInfoSchema):
            if schema is None:
                return None, None
            desc_dict = {ele["key"]: ele["desc"] for ele in schema}
            type_dict = {ele["key"]: ele["type"] for ele in schema}
            return desc_dict, type_dict

        def compare_schema(old_schema: DocInfoSchema, new_schema: DocInfoSchema):
            old_desc_dict, old_type_dict = format_schema_to_dict(old_schema)
            new_desc_dict, new_type_dict = format_schema_to_dict(new_schema)
            return old_desc_dict == new_desc_dict and old_type_dict == new_type_dict

        # 1. Check valid arguments
        if sql_manager.check_connection().status != DBStatus.SUCCESS:
            raise RuntimeError(f'Failed to connect to sql manager: {sql_manager._gen_conn_url()}')
        pre_doc_table_schema = None
        if self._doc_to_db_processor:
            pre_doc_table_schema = self._doc_to_db_processor.doc_info_schema
        assert pre_doc_table_schema or schma, "doc_table_schma must be given"

        schema_equal = compare_schema(pre_doc_table_schema, schma)
        assert (
            schema_equal or force_refresh is True
        ), "When changing doc_table_schema, force_refresh should be set to True"

        # 2. Init handler if needed
        need_init_processor = False
        if self._doc_to_db_processor is None:
            need_init_processor = True
        else:
            # avoid reinit for the same db
            if sql_manager != self._doc_to_db_processor.sql_manager:
                need_init_processor = True
        if need_init_processor:
            self._doc_to_db_processor = DocToDbProcessor(sql_manager)

        # 3. Reset doc_table_schema if needed
        if schma and not schema_equal:
            # This api call will clear existing db table "lazyllm_doc_elements"
            self._doc_to_db_processor._reset_doc_info_schema(schma)

    def get_sql_manager(self):
        if self._doc_to_db_processor is None:
            raise ValueError("Please call connect_sql_manager to init handler first")
        return self._doc_to_db_processor.sql_manager

    def extract_db_schema(
        self, llm: Union[OnlineChatModule, TrainableModule], print_schema: bool = False
    ) -> DocInfoSchema:
        file_paths = self._list_all_files_in_dataset()
        schema = extract_db_schema_from_files(file_paths, llm)
        if print_schema:
            lazyllm.LOG.info(f"Extracted Schema:\n\t{schema}\n")
        return schema

    def update_database(self, llm: Union[OnlineChatModule, TrainableModule]):
        assert self._doc_to_db_processor, "Please call connect_db to init handler first"
        file_paths = self._list_all_files_in_dataset()
        info_dicts = self._doc_to_db_processor.extract_info_from_docs(llm, file_paths)
        self._doc_to_db_processor.export_info_to_db(info_dicts)

    @deprecated('Document(dataset_path, manager=doc.manager, name=xx, doc_fields=xx, store_conf=xx)')
    def create_kb_group(self, name: str, doc_fields: Optional[Dict[str, DocField]] = None,
                        store_conf: Optional[Dict] = None) -> "Document":
        self._manager.add_kb_group(name=name, doc_fields=doc_fields, store_conf=store_conf)
        doc = copy.copy(self)
        doc._curr_group = name
        return doc

    @property
    @deprecated('Document._manager')
    def _impls(self): return self._manager

    @property
    def _impl(self) -> DocImpl: return self._manager.get_doc_by_kb_group(self._curr_group)

    @property
    def manager(self): return self._manager._processor or self._manager

    def activate_group(self, group_name: str, embed_keys: Optional[Union[str, List[str]]] = None):
        if isinstance(embed_keys, str): embed_keys = [embed_keys]
        elif embed_keys is None: embed_keys = []
        self._impl.activate_group(group_name, embed_keys)

    def activate_groups(self, groups: Union[str, List[str]]):
        if isinstance(groups, str): groups = [groups]
        for group in groups:
            self.activate_group(group)

    @DynamicDescriptor
    def create_node_group(self, name: str = None, *, transform: Callable, parent: str = LAZY_ROOT_NAME,
                          trans_node: bool = None, num_workers: int = 0, display_name: str = None,
                          group_type: NodeGroupType = NodeGroupType.CHUNK, **kwargs) -> None:
        """
创建一个由指定规则生成的 node group。

Args:
    name (str): node group 的名称。
    transform (Callable): 将 node 转换成 node group 的转换规则,函数原型是 `(DocNode, group_name, **kwargs) -> List[DocNode]`。目前内置的有 [SentenceSplitter][lazyllm.tools.SentenceSplitter]。用户也可以自定义转换规则。
    trans_node (bool): 决定了transform的输入和输出是 `DocNode` 还是 `str` ,默认为None。只有在 `transform` 为 `Callable` 时才可以设置为true。
    num_workers (int): Transform时所用的新线程数量,默认为0
    parent (str): 需要进一步转换的节点。转换之后得到的一系列新的节点将会作为该父节点的子节点。如果不指定则从根节点开始转换。
    kwargs: 和具体实现相关的参数。


Examples:

    >>> import lazyllm
    >>> from lazyllm.tools import Document, SentenceSplitter
    >>> m = lazyllm.OnlineEmbeddingModule(source="glm")
    >>> documents = Document(dataset_path='your_doc_path', embed=m, manager=False)
    >>> documents.create_node_group(name="sentences", transform=SentenceSplitter, chunk_size=1024, chunk_overlap=100)
    """
        if isinstance(self, type):
            DocImpl.create_global_node_group(name, transform=transform, parent=parent, trans_node=trans_node,
                                             num_workers=num_workers, display_name=display_name,
                                             group_type=group_type, **kwargs)
        else:
            self._impl.create_node_group(name, transform=transform, parent=parent, trans_node=trans_node,
                                         num_workers=num_workers, display_name=display_name, group_type=group_type,
                                         **kwargs)

    @DynamicDescriptor
    def add_reader(self, pattern: str, func: Optional[Callable] = None):
        """
用于实例指定文件读取器,作用范围仅对注册的 Document 对象可见。注册的文件读取器必须是 Callable 对象。只能通过函数调用的方式进行注册。并且通过实例注册的文件读取器的优先级高于通过类注册的文件读取器,并且实例和类注册的文件读取器的优先级高于系统默认的文件读取器。即优先级的顺序是:实例文件读取器 > 类文件读取器 > 系统默认文件读取器。

Args:
    pattern (str): 文件读取器适用的匹配规则
    func (Callable): 文件读取器,必须是Callable的对象


Examples:

    >>> from lazyllm.tools.rag import Document, DocNode
    >>> from lazyllm.tools.rag.readers import ReaderBase
    >>> class YmlReader(ReaderBase):
    ...     def _load_data(self, file, fs=None):
    ...         try:
    ...             import yaml
    ...         except ImportError:
    ...             raise ImportError("yaml is required to read YAML file: `pip install pyyaml`")
    ...         with open(file, 'r') as f:
    ...             data = yaml.safe_load(f)
    ...         print("Call the class YmlReader.")
    ...         return [DocNode(text=data)]
    ...
    >>> def processYml(file):
    ...     with open(file, 'r') as f:
    ...         data = f.read()
    ...     print("Call the function processYml.")
    ...     return [DocNode(text=data)]
    ...
    >>> doc1 = Document(dataset_path="your_files_path", create_ui=False)
    >>> doc2 = Document(dataset_path="your_files_path", create_ui=False)
    >>> doc1.add_reader("**/*.yml", YmlReader)
    >>> print(doc1._impl._local_file_reader)
    {'**/*.yml': <class '__main__.YmlReader'>}
    >>> print(doc2._impl._local_file_reader)
    {}
    >>> files = ["your_yml_files"]
    >>> Document.register_global_reader("**/*.yml", processYml)
    >>> doc1._impl._reader.load_data(input_files=files)
    Call the class YmlReader.
    >>> doc2._impl._reader.load_data(input_files=files)
    Call the function processYml.
    """
        if isinstance(self, type):
            return DocImpl.register_global_reader(pattern=pattern, func=func)
        else:
            self._impl.add_reader(pattern, func)

    @classmethod
    def register_global_reader(cls, pattern: str, func: Optional[Callable] = None):
        """
用于指定文件读取器,作用范围对于所有的 Document 对象都可见。注册的文件读取器必须是 Callable 对象。可以使用装饰器的方式进行注册,也可以通过函数调用的方式进行注册。

Args:
    pattern (str): 文件读取器适用的匹配规则
    func (Callable): 文件读取器,必须是Callable的对象


Examples:

    >>> from lazyllm.tools.rag import Document, DocNode
    >>> @Document.register_global_reader("**/*.yml")
    >>> def processYml(file):
    ...     with open(file, 'r') as f:
    ...         data = f.read()
    ...     return [DocNode(text=data)]
    ...
    >>> doc1 = Document(dataset_path="your_files_path", create_ui=False)
    >>> doc2 = Document(dataset_path="your_files_path", create_ui=False)
    >>> files = ["your_yml_files"]
    >>> docs1 = doc1._impl._reader.load_data(input_files=files)
    >>> docs2 = doc2._impl._reader.load_data(input_files=files)
    >>> print(docs1[0].text == docs2[0].text)
    # True
    """
        return cls.add_reader(pattern, func)

    def get_store(self):
        return StorePlaceholder()

    def get_embed(self):
        return EmbedPlaceholder()

    def register_index(self, index_type: str, index_cls: IndexBase, *args, **kwargs) -> None:
        self._impl.register_index(index_type, index_cls, *args, **kwargs)

    def _forward(self, func_name: str, *args, **kw):
        return self._manager(self._curr_group, func_name, *args, **kw)

    def find_parent(self, target) -> Callable:
        """
查找指定节点的父节点。

Args:
    group (str): 需要查找的节点组名称


Examples:

    >>> import lazyllm
    >>> from lazyllm.tools import Document, SentenceSplitter
    >>> m = lazyllm.OnlineEmbeddingModule(source="glm")
    >>> documents = Document(dataset_path='your_doc_path', embed=m, manager=False)
    >>> documents.create_node_group(name="parent", transform=SentenceSplitter, chunk_size=1024, chunk_overlap=100)
    >>> documents.create_node_group(name="children", transform=SentenceSplitter, parent="parent", chunk_size=1024, chunk_overlap=100)
    >>> documents.find_parent('children')
    """
        return functools.partial(self._forward, 'find_parent', group=target)

    def find_children(self, target) -> Callable:
        """
查找指定节点的子节点。

Args:
    group (str): 需要查找的节点组名称


Examples:

    >>> import lazyllm
    >>> from lazyllm.tools import Document, SentenceSplitter
    >>> m = lazyllm.OnlineEmbeddingModule(source="glm")
    >>> documents = Document(dataset_path='your_doc_path', embed=m, manager=False)
    >>> documents.create_node_group(name="parent", transform=SentenceSplitter, chunk_size=1024, chunk_overlap=100)
    >>> documents.create_node_group(name="children", transform=SentenceSplitter, parent="parent", chunk_size=1024, chunk_overlap=100)
    >>> documents.find_children('parent')
    """
        return functools.partial(self._forward, 'find_children', group=target)

    def find(self, target) -> Callable:
        return functools.partial(self._forward, 'find', group=target)

    def forward(self, *args, **kw) -> List[DocNode]:
        return self._forward('retrieve', *args, **kw)

    def clear_cache(self, group_names: Optional[List[str]] = None) -> None:
        return self._forward('clear_cache', group_names)

    def _get_post_process_tasks(self):
        return lazyllm.pipeline(lambda *a: self._forward('_lazy_init'))

    def __repr__(self):
        return lazyllm.make_repr("Module", "Document", manager=hasattr(self._manager, '_manager'),
                                 server=isinstance(self._manager._kbs, ServerModule))

add_reader(pattern, func=None)

用于实例指定文件读取器,作用范围仅对注册的 Document 对象可见。注册的文件读取器必须是 Callable 对象。只能通过函数调用的方式进行注册。并且通过实例注册的文件读取器的优先级高于通过类注册的文件读取器,并且实例和类注册的文件读取器的优先级高于系统默认的文件读取器。即优先级的顺序是:实例文件读取器 > 类文件读取器 > 系统默认文件读取器。

Parameters:

  • pattern (str) –

    文件读取器适用的匹配规则

  • func (Callable, default: None ) –

    文件读取器,必须是Callable的对象

Examples:

>>> from lazyllm.tools.rag import Document, DocNode
>>> from lazyllm.tools.rag.readers import ReaderBase
>>> class YmlReader(ReaderBase):
...     def _load_data(self, file, fs=None):
...         try:
...             import yaml
...         except ImportError:
...             raise ImportError("yaml is required to read YAML file: `pip install pyyaml`")
...         with open(file, 'r') as f:
...             data = yaml.safe_load(f)
...         print("Call the class YmlReader.")
...         return [DocNode(text=data)]
...
>>> def processYml(file):
...     with open(file, 'r') as f:
...         data = f.read()
...     print("Call the function processYml.")
...     return [DocNode(text=data)]
...
>>> doc1 = Document(dataset_path="your_files_path", create_ui=False)
>>> doc2 = Document(dataset_path="your_files_path", create_ui=False)
>>> doc1.add_reader("**/*.yml", YmlReader)
>>> print(doc1._impl._local_file_reader)
{'**/*.yml': <class '__main__.YmlReader'>}
>>> print(doc2._impl._local_file_reader)
{}
>>> files = ["your_yml_files"]
>>> Document.register_global_reader("**/*.yml", processYml)
>>> doc1._impl._reader.load_data(input_files=files)
Call the class YmlReader.
>>> doc2._impl._reader.load_data(input_files=files)
Call the function processYml.
Source code in lazyllm/tools/rag/document.py
    @DynamicDescriptor
    def add_reader(self, pattern: str, func: Optional[Callable] = None):
        """
用于实例指定文件读取器,作用范围仅对注册的 Document 对象可见。注册的文件读取器必须是 Callable 对象。只能通过函数调用的方式进行注册。并且通过实例注册的文件读取器的优先级高于通过类注册的文件读取器,并且实例和类注册的文件读取器的优先级高于系统默认的文件读取器。即优先级的顺序是:实例文件读取器 > 类文件读取器 > 系统默认文件读取器。

Args:
    pattern (str): 文件读取器适用的匹配规则
    func (Callable): 文件读取器,必须是Callable的对象


Examples:

    >>> from lazyllm.tools.rag import Document, DocNode
    >>> from lazyllm.tools.rag.readers import ReaderBase
    >>> class YmlReader(ReaderBase):
    ...     def _load_data(self, file, fs=None):
    ...         try:
    ...             import yaml
    ...         except ImportError:
    ...             raise ImportError("yaml is required to read YAML file: `pip install pyyaml`")
    ...         with open(file, 'r') as f:
    ...             data = yaml.safe_load(f)
    ...         print("Call the class YmlReader.")
    ...         return [DocNode(text=data)]
    ...
    >>> def processYml(file):
    ...     with open(file, 'r') as f:
    ...         data = f.read()
    ...     print("Call the function processYml.")
    ...     return [DocNode(text=data)]
    ...
    >>> doc1 = Document(dataset_path="your_files_path", create_ui=False)
    >>> doc2 = Document(dataset_path="your_files_path", create_ui=False)
    >>> doc1.add_reader("**/*.yml", YmlReader)
    >>> print(doc1._impl._local_file_reader)
    {'**/*.yml': <class '__main__.YmlReader'>}
    >>> print(doc2._impl._local_file_reader)
    {}
    >>> files = ["your_yml_files"]
    >>> Document.register_global_reader("**/*.yml", processYml)
    >>> doc1._impl._reader.load_data(input_files=files)
    Call the class YmlReader.
    >>> doc2._impl._reader.load_data(input_files=files)
    Call the function processYml.
    """
        if isinstance(self, type):
            return DocImpl.register_global_reader(pattern=pattern, func=func)
        else:
            self._impl.add_reader(pattern, func)

create_node_group(name=None, *, transform, parent=LAZY_ROOT_NAME, trans_node=None, num_workers=0, display_name=None, group_type=NodeGroupType.CHUNK, **kwargs)

创建一个由指定规则生成的 node group。

Parameters:

  • name (str, default: None ) –

    node group 的名称。

  • transform (Callable) –

    将 node 转换成 node group 的转换规则,函数原型是 (DocNode, group_name, **kwargs) -> List[DocNode]。目前内置的有 SentenceSplitter。用户也可以自定义转换规则。

  • trans_node (bool, default: None ) –

    决定了transform的输入和输出是 DocNode 还是 str ,默认为None。只有在 transformCallable 时才可以设置为true。

  • num_workers (int, default: 0 ) –

    Transform时所用的新线程数量,默认为0

  • parent (str, default: LAZY_ROOT_NAME ) –

    需要进一步转换的节点。转换之后得到的一系列新的节点将会作为该父节点的子节点。如果不指定则从根节点开始转换。

  • kwargs

    和具体实现相关的参数。

Examples:

>>> import lazyllm
>>> from lazyllm.tools import Document, SentenceSplitter
>>> m = lazyllm.OnlineEmbeddingModule(source="glm")
>>> documents = Document(dataset_path='your_doc_path', embed=m, manager=False)
>>> documents.create_node_group(name="sentences", transform=SentenceSplitter, chunk_size=1024, chunk_overlap=100)
Source code in lazyllm/tools/rag/document.py
    @DynamicDescriptor
    def create_node_group(self, name: str = None, *, transform: Callable, parent: str = LAZY_ROOT_NAME,
                          trans_node: bool = None, num_workers: int = 0, display_name: str = None,
                          group_type: NodeGroupType = NodeGroupType.CHUNK, **kwargs) -> None:
        """
创建一个由指定规则生成的 node group。

Args:
    name (str): node group 的名称。
    transform (Callable): 将 node 转换成 node group 的转换规则,函数原型是 `(DocNode, group_name, **kwargs) -> List[DocNode]`。目前内置的有 [SentenceSplitter][lazyllm.tools.SentenceSplitter]。用户也可以自定义转换规则。
    trans_node (bool): 决定了transform的输入和输出是 `DocNode` 还是 `str` ,默认为None。只有在 `transform` 为 `Callable` 时才可以设置为true。
    num_workers (int): Transform时所用的新线程数量,默认为0
    parent (str): 需要进一步转换的节点。转换之后得到的一系列新的节点将会作为该父节点的子节点。如果不指定则从根节点开始转换。
    kwargs: 和具体实现相关的参数。


Examples:

    >>> import lazyllm
    >>> from lazyllm.tools import Document, SentenceSplitter
    >>> m = lazyllm.OnlineEmbeddingModule(source="glm")
    >>> documents = Document(dataset_path='your_doc_path', embed=m, manager=False)
    >>> documents.create_node_group(name="sentences", transform=SentenceSplitter, chunk_size=1024, chunk_overlap=100)
    """
        if isinstance(self, type):
            DocImpl.create_global_node_group(name, transform=transform, parent=parent, trans_node=trans_node,
                                             num_workers=num_workers, display_name=display_name,
                                             group_type=group_type, **kwargs)
        else:
            self._impl.create_node_group(name, transform=transform, parent=parent, trans_node=trans_node,
                                         num_workers=num_workers, display_name=display_name, group_type=group_type,
                                         **kwargs)

find_children(target)

查找指定节点的子节点。

Parameters:

  • group (str) –

    需要查找的节点组名称

Examples:

>>> import lazyllm
>>> from lazyllm.tools import Document, SentenceSplitter
>>> m = lazyllm.OnlineEmbeddingModule(source="glm")
>>> documents = Document(dataset_path='your_doc_path', embed=m, manager=False)
>>> documents.create_node_group(name="parent", transform=SentenceSplitter, chunk_size=1024, chunk_overlap=100)
>>> documents.create_node_group(name="children", transform=SentenceSplitter, parent="parent", chunk_size=1024, chunk_overlap=100)
>>> documents.find_children('parent')
Source code in lazyllm/tools/rag/document.py
    def find_children(self, target) -> Callable:
        """
查找指定节点的子节点。

Args:
    group (str): 需要查找的节点组名称


Examples:

    >>> import lazyllm
    >>> from lazyllm.tools import Document, SentenceSplitter
    >>> m = lazyllm.OnlineEmbeddingModule(source="glm")
    >>> documents = Document(dataset_path='your_doc_path', embed=m, manager=False)
    >>> documents.create_node_group(name="parent", transform=SentenceSplitter, chunk_size=1024, chunk_overlap=100)
    >>> documents.create_node_group(name="children", transform=SentenceSplitter, parent="parent", chunk_size=1024, chunk_overlap=100)
    >>> documents.find_children('parent')
    """
        return functools.partial(self._forward, 'find_children', group=target)

find_parent(target)

查找指定节点的父节点。

Parameters:

  • group (str) –

    需要查找的节点组名称

Examples:

>>> import lazyllm
>>> from lazyllm.tools import Document, SentenceSplitter
>>> m = lazyllm.OnlineEmbeddingModule(source="glm")
>>> documents = Document(dataset_path='your_doc_path', embed=m, manager=False)
>>> documents.create_node_group(name="parent", transform=SentenceSplitter, chunk_size=1024, chunk_overlap=100)
>>> documents.create_node_group(name="children", transform=SentenceSplitter, parent="parent", chunk_size=1024, chunk_overlap=100)
>>> documents.find_parent('children')
Source code in lazyllm/tools/rag/document.py
    def find_parent(self, target) -> Callable:
        """
查找指定节点的父节点。

Args:
    group (str): 需要查找的节点组名称


Examples:

    >>> import lazyllm
    >>> from lazyllm.tools import Document, SentenceSplitter
    >>> m = lazyllm.OnlineEmbeddingModule(source="glm")
    >>> documents = Document(dataset_path='your_doc_path', embed=m, manager=False)
    >>> documents.create_node_group(name="parent", transform=SentenceSplitter, chunk_size=1024, chunk_overlap=100)
    >>> documents.create_node_group(name="children", transform=SentenceSplitter, parent="parent", chunk_size=1024, chunk_overlap=100)
    >>> documents.find_parent('children')
    """
        return functools.partial(self._forward, 'find_parent', group=target)

register_global_reader(pattern, func=None) classmethod

用于指定文件读取器,作用范围对于所有的 Document 对象都可见。注册的文件读取器必须是 Callable 对象。可以使用装饰器的方式进行注册,也可以通过函数调用的方式进行注册。

Parameters:

  • pattern (str) –

    文件读取器适用的匹配规则

  • func (Callable, default: None ) –

    文件读取器,必须是Callable的对象

Examples:

>>> from lazyllm.tools.rag import Document, DocNode
>>> @Document.register_global_reader("**/*.yml")
>>> def processYml(file):
...     with open(file, 'r') as f:
...         data = f.read()
...     return [DocNode(text=data)]
...
>>> doc1 = Document(dataset_path="your_files_path", create_ui=False)
>>> doc2 = Document(dataset_path="your_files_path", create_ui=False)
>>> files = ["your_yml_files"]
>>> docs1 = doc1._impl._reader.load_data(input_files=files)
>>> docs2 = doc2._impl._reader.load_data(input_files=files)
>>> print(docs1[0].text == docs2[0].text)
# True
Source code in lazyllm/tools/rag/document.py
    @classmethod
    def register_global_reader(cls, pattern: str, func: Optional[Callable] = None):
        """
用于指定文件读取器,作用范围对于所有的 Document 对象都可见。注册的文件读取器必须是 Callable 对象。可以使用装饰器的方式进行注册,也可以通过函数调用的方式进行注册。

Args:
    pattern (str): 文件读取器适用的匹配规则
    func (Callable): 文件读取器,必须是Callable的对象


Examples:

    >>> from lazyllm.tools.rag import Document, DocNode
    >>> @Document.register_global_reader("**/*.yml")
    >>> def processYml(file):
    ...     with open(file, 'r') as f:
    ...         data = f.read()
    ...     return [DocNode(text=data)]
    ...
    >>> doc1 = Document(dataset_path="your_files_path", create_ui=False)
    >>> doc2 = Document(dataset_path="your_files_path", create_ui=False)
    >>> files = ["your_yml_files"]
    >>> docs1 = doc1._impl._reader.load_data(input_files=files)
    >>> docs2 = doc2._impl._reader.load_data(input_files=files)
    >>> print(docs1[0].text == docs2[0].text)
    # True
    """
        return cls.add_reader(pattern, func)

lazyllm.tools.rag.store.ChromadbStore

Bases: LazyLLMStoreBase

Source code in lazyllm/tools/rag/store/vector/chroma_store.py
class ChromadbStore(LazyLLMStoreBase):
    capability = StoreCapability.VECTOR
    need_embedding = True
    supports_index_registration = False

    def __init__(self, uri: Optional[str] = None, dir: Optional[str] = None,
                 index_kwargs: Optional[Union[Dict, List]] = None, client_kwargs: Optional[Dict] = None,
                 **kwargs) -> None:
        assert uri or (dir), "uri or dir must be provided"
        self._index_kwargs = index_kwargs or DEFAULT_INDEX_CONFIG
        self._client_kwargs = client_kwargs or {}
        if dir:
            self._dir = dir
        else:
            self._dir, self._host, self._port = self._parse_uri(uri)
        self._primary_key = 'uid'

    @property
    def dir(self):
        if not self._dir: return None
        p = Path(self._dir)
        p = p if p.suffix else (p / "chroma.sqlite3")
        return str(p.resolve(strict=False))

    def _parse_uri(self, uri: str):
        windows_drive = re.match(r"^[a-zA-Z]:[\\/]", uri or "")
        if ("://" not in uri) and (windows_drive or os.path.isabs(uri)):
            return os.path.abspath(uri), None, None

        p = urlparse(uri)

        if p.scheme == "":
            return os.path.abspath(uri), None, None

        if p.scheme == "file":
            path = p.path
            if os.name == "nt" and path.startswith("/") and re.match(r"^/[a-zA-Z]:", path):
                path = path.lstrip("/")  # file:///C:/... -> C:/...
            return os.path.abspath(path), None, None

        scheme = p.scheme
        if scheme.startswith("chroma+"):
            scheme = scheme.split("+", 1)[1]  # http or https

        if scheme in ("http", "https"):
            host = p.hostname or "127.0.0.1"
            port = p.port or (443 if scheme == "https" else 80)
            return None, host, port

        raise ValueError(f"Unsupported URI scheme in '{uri}'. "
                         "Use file:///path or plain path for local; http(s)://host:port for remote.")

    @override
    def connect(self, embed_dims: Optional[Dict[str, int]] = None,
                embed_datatypes: Optional[Dict[str, DataType]] = None,
                global_metadata_desc: Optional[Dict[str, GlobalMetadataDesc]] = None, **kwargs):
        self._global_metadata_desc = global_metadata_desc or {}
        self._embed_dims = embed_dims or {}
        self._embed_datatypes = embed_datatypes or {}
        for k, v in self._global_metadata_desc.items():
            if v.data_type not in [DataType.VARCHAR, DataType.INT32, DataType.FLOAT, DataType.BOOLEAN]:
                raise ValueError(f"[Chromadb Store] Unsupported data type {v.data_type} for global metadata {k}"
                                 " (only string, int, float, bool are supported)")
        for k, v in self._embed_datatypes.items():
            if v not in [DataType.FLOAT_VECTOR, DataType.SPARSE_FLOAT_VECTOR]:
                raise ValueError(f"[Chromadb Store] Unsupported data type {v} for embed key {k}"
                                 " (only float vector and sparse float vector are supported)")
        if self._dir:
            self._client = chromadb.PersistentClient(path=self._dir, **self._client_kwargs)
            LOG.success(f"Initialzed chromadb in path: {self._dir}")
        else:
            self._client = chromadb.HttpClient(host=self._host, port=self._port, **self._client_kwargs)
            LOG.success(f"Initialzed chromadb in host: {self._host}, port: {self._port}")

    @override
    def upsert(self, collection_name: str, data: List[dict]) -> bool:
        try:
            # NOTE chromadb only support single embedding for each collection
            if not data: return
            data_embeddings = data[0].get('embedding', {})
            if not data_embeddings: return
            embed_keys = list(data_embeddings.keys())
            for embed_key in embed_keys:
                if embed_key not in self._embed_datatypes:
                    raise ValueError(f"Embed key {embed_key} not found in embed_datatypes")
                collection = self._client.get_or_create_collection(
                    name=self._gen_collection_name(collection_name, embed_key), configuration=self._index_kwargs)
                for i in range(0, len(data), INSERT_BATCH_SIZE):
                    collection.upsert(**self._serialize_data(data[i: i + INSERT_BATCH_SIZE], embed_key))
            return True
        except Exception as e:
            LOG.error(f"[Chromadb Store - upsert] Failed to create collection {collection_name}: {e}")
            LOG.error(traceback.format_exc())
            return False

    def _serialize_data(self, data: List[dict], embed_key: str) -> List[dict]:
        res = {'ids': [], 'embeddings': [], 'metadatas': []}
        for d in data:
            res['ids'].append(d.get('uid'))
            res['embeddings'].append(d.get('embedding', {}).get(embed_key))
            res['metadatas'].append({self._gen_global_meta_key(k): v for k, v in d.get('global_meta', {}).items()
                                     if k in self._global_metadata_desc})
        return res

    @override
    def delete(self, collection_name: str, criteria: Optional[dict] = None, **kwargs) -> bool:
        try:
            if not criteria:
                for embed_key in self._embed_datatypes.keys():
                    try:
                        self._client.delete_collection(name=self._gen_collection_name(collection_name, embed_key))
                    except Exception:
                        continue
                return True
            else:
                filters = self._construct_criteria(criteria)
                for embed_key in self._embed_datatypes.keys():
                    collection = self._client.get_collection(name=self._gen_collection_name(collection_name, embed_key))
                    collection.delete(**filters)
                return True
        except Exception as e:
            LOG.error(f"[Chromadb Store - delete] Failed to delete collection {collection_name}: {e}")
            LOG.error(traceback.format_exc())
            return False

    @override
    def get(self, collection_name: str, criteria: Optional[dict] = None, **kwargs) -> List[dict]:
        try:
            filters = self._construct_criteria(criteria) if criteria else {}
            all_data = []
            for key in self._embed_datatypes:
                try:
                    coll = self._client.get_collection(
                        name=self._gen_collection_name(collection_name, key)
                    )
                    data = coll.get(include=['metadatas', 'embeddings'], **filters)
                    all_data.append((key, data))
                except Exception:
                    continue

            res: Dict[str, Dict[str, Any]] = defaultdict(lambda: {
                'uid': None, 'global_meta': {}, 'embedding': {}})
            for embed_key, data in all_data:
                ids = data['ids']
                metas = data['metadatas']
                embs = data['embeddings']

                for uid, meta, emb in zip(ids, metas, embs):
                    entry = res[uid]
                    entry['uid'] = uid
                    if not entry['global_meta']:
                        entry['global_meta'] = {
                            k[len(GLOBAL_META_KEY_PREFIX):]: v
                            for k, v in meta.items()
                        }
                    entry['embedding'][embed_key] = list(emb)
            return list(res.values())
        except Exception as e:
            LOG.error(f"[ChromadbStore - get] task fail: {e}")
            LOG.error(traceback.format_exc())

    @override
    def search(self, collection_name: str, query_embedding: List[float], embed_key: str, topk: Optional[int] = 10,
               filters: Optional[Dict[str, Union[str, int, List, Set]]] = None,
               **kwargs) -> List[dict]:
        try:
            collection = self._client.get_collection(name=self._gen_collection_name(collection_name, embed_key))

            filters = self._construct_filter_expr(filters) if filters else {}
            query_results = collection.query(query_embeddings=[query_embedding], n_results=topk, **filters)
            res = []
            for i, r_list in enumerate(query_results['ids']):
                for j, uid in enumerate(r_list):
                    dis = query_results['distances'][i][j]
                    res.append({'uid': uid, 'score': 1 - dis})
            return res
        except Exception as e:
            LOG.error(f"[ChromadbStore - search] task fail: {e}")
            LOG.error(traceback.format_exc())

    def _construct_criteria(self, criteria: dict) -> dict:
        res = {}
        if self._primary_key in criteria:
            res['ids'] = criteria[self._primary_key]
        else:
            res['where'] = {}
            for key, vaule in criteria.items():
                if key not in self._global_metadata_desc:
                    continue
                field_key = self._gen_global_meta_key(key)
                if isinstance(vaule, list):
                    res['where'][field_key] = {'$in': vaule}
                elif isinstance(vaule, str):
                    res['where'][field_key] = {'$eq': vaule}
                else:
                    raise ValueError(f'invalid criteria type: {type(vaule)}')
        return res

    def _construct_filter_expr(self, filters: Dict[str, Union[str, int, List, Set]]) -> str:
        ret = {}
        for name, candidates in filters.items():
            desc = self._global_metadata_desc.get(name)
            if not desc:
                raise ValueError(f'cannot find desc of field [{name}]')
            key = self._gen_global_meta_key(name)
            if isinstance(candidates, str):
                candidates = [candidates]
            elif (not isinstance(candidates, List)) and (not isinstance(candidates, Set)):
                candidates = list(candidates)
            ret[key] = {'$in': candidates}
        return {'where': ret}

    def _gen_global_meta_key(self, k: str) -> str:
        return GLOBAL_META_KEY_PREFIX + k

    def _gen_collection_name(self, collection_name: str, embed_key: str) -> str:
        return collection_name + '_' + embed_key + "_embed"

lazyllm.tools.rag.store.MilvusStore

Bases: LazyLLMStoreBase

Source code in lazyllm/tools/rag/store/vector/milvus_store.py
class MilvusStore(LazyLLMStoreBase):
    capability = StoreCapability.VECTOR
    need_embedding = True
    supports_index_registration = False

    def __init__(self, uri: str = '', db_name: str = 'lazyllm', index_kwargs: Optional[Union[Dict, List]] = None,
                 client_kwargs: Optional[Dict] = None):
        # one database, different collection for each group (for standalone, add prefix to collection name)
        # when there's data need upsert, collection creation happen.
        self._uri = uri
        self._db_name = db_name
        self._index_kwargs = index_kwargs
        self._client_kwargs = client_kwargs or {}
        self._primary_key = 'uid'
        self._client = None
        if self._uri and parse.urlparse(self._uri).scheme.lower() in ['unix', 'http', 'https', 'tcp', 'grpc']:
            self._is_remote = True
        else:
            self._is_remote = False

    @property
    def dir(self):
        if self._is_remote: return None
        p = Path(self._uri)
        p = p if p.suffix else (p / "milvus.db")
        return str(p.resolve(strict=False))

    @override
    def connect(self, embed_dims: Optional[Dict[str, int]] = None,
                embed_datatypes: Optional[Dict[str, DataType]] = None,
                global_metadata_desc: Optional[Dict[str, GlobalMetadataDesc]] = None, **kwargs):
        self._embed_dims = embed_dims or {}
        self._embed_datatypes = embed_datatypes or {}
        self._global_metadata_desc = global_metadata_desc or {}
        self._set_constants()
        self._connect()
        LOG.info("[Milvus Vector Store] init success!")
        self._disconnect()

    def _connect(self):
        try:
            self._client = pymilvus.MilvusClient(uri=self._uri, **self._client_kwargs)
            if self._is_remote and self._db_name:
                existing_dbs = self._client.list_databases()
                if self._db_name not in existing_dbs:
                    self._client.create_database(self._db_name)
                self._client.using_database(self._db_name)
        except Exception as e:
            LOG.error(f'[Milvus Store - connect] error: {e}')

    def _disconnect(self):
        try:
            if self._client:
                self._client.close()
                self._client = None
        except Exception as e:
            LOG.error(f'[Milvus Store - disconnect] error: {e}')

    @override
    def upsert(self, collection_name: str, data: List[dict]) -> bool:
        try:
            if not data: return
            data_embeddings = data[0].get('embedding', {})
            if not data_embeddings: return
            self._connect()
            if not self._client.has_collection(collection_name):
                embed_kwargs = {}
                for embed_key in data_embeddings.keys():
                    assert self._embed_datatypes.get(embed_key), \
                        f'cannot find embedding params for embed [{embed_key}]'
                    if embed_key not in embed_kwargs:
                        embed_kwargs[embed_key] = {'dtype': self._type2milvus[self._embed_datatypes[embed_key]]}
                    if self._embed_dims.get(embed_key): embed_kwargs[embed_key]['dim'] = self._embed_dims[embed_key]
                self._create_collection(collection_name, embed_kwargs)

            for i in range(0, len(data), MILVUS_UPSERT_BATCH_SIZE):
                self._client.upsert(collection_name=collection_name,
                                    data=[self._serialize_data(d) for d in data[i:i + MILVUS_UPSERT_BATCH_SIZE]])
            self._disconnect()
            return True
        except Exception as e:
            LOG.error(f'[Milvus Store - upsert] error: {e}')
            LOG.error(traceback.format_exc())
            self._disconnect()
            return False

    @override
    def delete(self, collection_name: str, criteria: Optional[dict] = None, **kwargs) -> bool:
        try:
            self._connect()
            if not self._client.has_collection(collection_name):
                return True
            self._client.load_collection(collection_name)
            if not criteria:
                self._client.drop_collection(collection_name=collection_name)
            else:
                self._client.delete(collection_name=collection_name, **self._construct_criteria(criteria))
            self._disconnect()
            return True
        except Exception as e:
            LOG.error(f'[Milvus Store - delete] error: {e}')
            self._disconnect()
            return False

    @override
    def get(self, collection_name: str, criteria: Optional[dict] = None, **kwargs) -> List[dict]:
        try:
            self._connect()
            if not self._client.has_collection(collection_name):
                return []
            self._client.load_collection(collection_name)
            col_desc = self._client.describe_collection(collection_name=collection_name)
            field_names = [field.get('name') for field in col_desc.get('fields', [])
                           if field.get('name').startswith(EMBED_PREFIX)]
            if criteria and self._primary_key in criteria:
                res = self._client.get(collection_name=collection_name, ids=criteria[self._primary_key])
            else:
                filters = self._construct_criteria(criteria) if criteria else {}
                if version.parse(pymilvus.__version__) >= version.parse('2.4.11'):
                    iterator = self._client.query_iterator(collection_name=collection_name,
                                                           batch_size=MILVUS_PAGINATION_OFFSET,
                                                           output_fields=field_names, **filters)
                    res = []
                    while True:
                        result = iterator.next()
                        if not result:
                            iterator.close()
                            break
                        res += result
                else:
                    res = self._client.query(collection_name=collection_name, output_fields=field_names, **filters)
            self._disconnect()
            return [self._deserialize_data(r) for r in res]
        except Exception as e:
            LOG.error(f'[Milvus Store - get] error: {e}')
            self._disconnect()
            return []

    def _set_constants(self):
        self._type2milvus = {
            DataType.VARCHAR: pymilvus.DataType.VARCHAR,
            DataType.ARRAY: pymilvus.DataType.ARRAY,
            DataType.FLOAT_VECTOR: pymilvus.DataType.FLOAT_VECTOR,
            DataType.INT32: pymilvus.DataType.INT32,
            DataType.INT64: pymilvus.DataType.INT64,
            DataType.SPARSE_FLOAT_VECTOR: pymilvus.DataType.SPARSE_FLOAT_VECTOR,
            DataType.STRING: pymilvus.DataType.STRING,
        }
        self._builtin_keys = {
            'uid': {'dtype': pymilvus.DataType.VARCHAR, 'max_length': 256, 'is_primary': True}
        }
        self._constant_fields = self._get_constant_fields()

    def _get_constant_fields(self) -> list:
        field_list = []
        for k, kws in self._builtin_keys.items():
            field_list.append(pymilvus.FieldSchema(name=k, **kws))
        for k, desc in self._global_metadata_desc.items():
            field_name = self._gen_global_meta_key(k)
            if desc.data_type == DataType.ARRAY:
                if desc.element_type is None:
                    raise ValueError(f'Milvus field [{field_name}]: '
                                     '`element_type` is required when `data_type` is ARRAY.')
                field_args = {'element_type': self._type2milvus[desc.element_type], 'max_capacity': desc.max_size}
                if desc.element_type == DataType.VARCHAR: field_args['max_length'] = 65535
            elif desc.data_type == DataType.VARCHAR:
                field_args = {'max_length': desc.max_size}
            else:
                field_args = {}
            field_list.append(pymilvus.FieldSchema(name=field_name, dtype=self._type2milvus[desc.data_type],
                                                   default_value=desc.default_value, **field_args))
        return field_list

    def _create_collection(self, collection_name: str, embed_kwargs: Dict[str, Dict]):  # noqa: C901
        field_list = copy.deepcopy(self._constant_fields)
        index_params = self._client.prepare_index_params()
        for k, kws in embed_kwargs.items():
            embed_field_name = self._gen_embed_key(k)
            field_list.append(pymilvus.FieldSchema(name=embed_field_name, **kws))
            index_params.add_index(field_name=embed_field_name, **kws)
            if isinstance(self._index_kwargs, list):
                for item in self._index_kwargs:
                    embed_key = item.get('embed_key', None)
                    if not embed_key:
                        raise ValueError(f'cannot find `embed_key` in `index_kwargs` of `{item}`')
                    if embed_key == k:
                        index_kwarg = item.copy()
                        index_kwarg.pop('embed_key', None)
                        index_params.add_index(field_name=embed_field_name, **index_kwarg)
                        break
            elif isinstance(self._index_kwargs, dict):
                index_params.add_index(field_name=embed_field_name, **self._index_kwargs)
        schema = pymilvus.CollectionSchema(fields=field_list, auto_id=False, enable_dynamic_field=False)
        self._client.create_collection(collection_name=collection_name, schema=schema, index_params=index_params)

    def _serialize_data(self, d: dict) -> dict:
        # only keep primary_key, embedding and global_meta
        res = {
            self._primary_key: d.get(self._primary_key, '')
        }
        for embed_key, value in d.get('embedding', {}).items():
            res[self._gen_embed_key(embed_key)] = value
        global_meta = d.get('global_meta', {})
        for name, desc in self._global_metadata_desc.items():
            value = global_meta.get(name, desc.default_value)
            if value is not None:
                res[self._gen_global_meta_key(name)] = value
        return res

    def _deserialize_data(self, d: dict) -> dict:
        res = {
            self._primary_key: d.get(self._primary_key, ''),
            'embedding': {}
        }
        for k, v in d.items():
            if k.startswith(EMBED_PREFIX):
                res['embedding'][k[len(EMBED_PREFIX):]] = v
        return res

    def _gen_embed_key(self, k: str) -> str:
        return EMBED_PREFIX + k

    def _gen_global_meta_key(self, k: str) -> str:
        return GLOBAL_META_KEY_PREFIX + k

    def _construct_criteria(self, criteria: dict) -> dict:
        res = {}
        criteria = dict(criteria)
        if self._primary_key in criteria:
            res['ids'] = criteria[self._primary_key]
        else:
            filter_str = ''
            for key, vaule in criteria.items():
                if key not in self._global_metadata_desc:
                    continue
                field_name = self._gen_global_meta_key(key)
                if len(filter_str) > 0:
                    filter_str += ' and '
                if isinstance(vaule, list):
                    filter_str += f'{field_name} in {vaule}'
                elif isinstance(vaule, str):
                    filter_str += f'{field_name} == "{vaule}"'
                else:
                    raise ValueError(f'invalid criteria type: {type(vaule)}')
            res['filter'] = filter_str
        return res

    @override
    def search(self, collection_name: str, query_embedding: Union[dict, List[float]], topk: int,
               filters: Optional[Dict[str, Union[List, set]]] = None, embed_key: Optional[str] = None,
               filter_str: Optional[str] = '', **kwargs) -> List[dict]:
        self._connect()
        if not embed_key or embed_key not in self._embed_datatypes:
            raise ValueError(f'[Milvus Store - search] Not supported or None `embed_key`: {embed_key}')
        res = []
        filter_expr = self._construct_filter_expr(filters) if filters else filter_str
        results = self._client.search(collection_name=collection_name, data=[query_embedding], limit=topk,
                                      anns_field=self._gen_embed_key(embed_key),
                                      filter=filter_expr)
        if len(results) != 1:
            raise ValueError(f'number of results [{len(results)}] != expected [1]')
        for result in results[0]:
            score = result.get('distance', 0)
            uid = result.get('id', result.get(self._primary_key, ''))
            if not uid:
                continue
            res.append({'uid': uid, 'score': score})
        self._disconnect()
        return res

    def _construct_filter_expr(self, filters: Dict[str, Union[str, int, List, Set]]) -> str:
        ret_str = ''
        if not filters:
            return ret_str
        for name, candidates in filters.items():
            desc = self._global_metadata_desc.get(name)
            if not desc:
                raise ValueError(f'cannot find desc of field [{name}]')
            key = self._gen_global_meta_key(name)
            if isinstance(candidates, str):
                candidates = [candidates]
            elif (not isinstance(candidates, list)) and (not isinstance(candidates, set)):
                candidates = list(candidates)
            if desc.data_type == DataType.ARRAY:
                ret_str += f'array_contains_any({key}, {candidates}) and '
            else:
                ret_str += f'{key} in {candidates} and '
        if len(ret_str) > 0:
            return ret_str[:-5]  # truncate the last ' and '
        return ret_str

lazyllm.tools.rag.readers.ReaderBase

Bases: ModuleBase

基础文档读取器类,提供了文档加载的基本接口。继承自ModuleBase,使用LazyLLMRegisterMetaClass作为元类。

Parameters:

  • return_trace (bool, default: True ) –

    是否返回处理过程的追踪信息。默认为True。

说明: - 提供了惰性加载和普通加载两种方式 - 子类需要实现_lazy_load_data方法 - 支持批量处理文档 - 自动转换为标准化的DocNode格式

Examples:

```python
from lazyllm.tools.rag.readers.readerBase import LazyLLMReaderBase
from lazyllm.tools.rag.doc_node import DocNode
from typing import Iterable

class CustomReader(LazyLLMReaderBase):
    def _lazy_load_data(self, file_paths: list, **kwargs) -> Iterable[DocNode]:
        for file_path in file_paths:
            # Process each file and yield DocNode
            content = self._read_file(file_path)
            yield DocNode(
                text=content,
                metadata={"source": file_path}
            )

# Create reader instance
reader = CustomReader(return_trace=True)

# Load documents
documents = reader.forward(file_paths=["doc1.txt", "doc2.txt"])
```
Source code in lazyllm/tools/rag/readers/readerBase.py
class LazyLLMReaderBase(ModuleBase, metaclass=LazyLLMRegisterMetaClass):
    """
基础文档读取器类,提供了文档加载的基本接口。继承自ModuleBase,使用LazyLLMRegisterMetaClass作为元类。

Args:
    return_trace (bool): 是否返回处理过程的追踪信息。默认为True。

**说明:**
- 提供了惰性加载和普通加载两种方式
- 子类需要实现_lazy_load_data方法
- 支持批量处理文档
- 自动转换为标准化的DocNode格式


Examples:

    ```python
    from lazyllm.tools.rag.readers.readerBase import LazyLLMReaderBase
    from lazyllm.tools.rag.doc_node import DocNode
    from typing import Iterable

    class CustomReader(LazyLLMReaderBase):
        def _lazy_load_data(self, file_paths: list, **kwargs) -> Iterable[DocNode]:
            for file_path in file_paths:
                # Process each file and yield DocNode
                content = self._read_file(file_path)
                yield DocNode(
                    text=content,
                    metadata={"source": file_path}
                )

    # Create reader instance
    reader = CustomReader(return_trace=True)

    # Load documents
    documents = reader.forward(file_paths=["doc1.txt", "doc2.txt"])
    ```
    """
    def __init__(self, *args, return_trace: bool = True, **kwargs):
        super().__init__(return_trace=return_trace)

    def _lazy_load_data(self, *args, **load_kwargs) -> Iterable[DocNode]:
        raise NotImplementedError(f"{self.__class__.__name__} does not implement lazy_load_data method.")

    def _load_data(self, *args, **load_kwargs) -> List[DocNode]:
        return list(self._lazy_load_data(*args, **load_kwargs))

    def forward(self, *args, **kwargs) -> List[DocNode]:
        return self._load_data(*args, **kwargs)

lazyllm.tools.rag.readers.readerBase.LazyLLMReaderBase

Bases: ModuleBase

基础文档读取器类,提供了文档加载的基本接口。继承自ModuleBase,使用LazyLLMRegisterMetaClass作为元类。

Parameters:

  • return_trace (bool, default: True ) –

    是否返回处理过程的追踪信息。默认为True。

说明: - 提供了惰性加载和普通加载两种方式 - 子类需要实现_lazy_load_data方法 - 支持批量处理文档 - 自动转换为标准化的DocNode格式

Examples:

```python
from lazyllm.tools.rag.readers.readerBase import LazyLLMReaderBase
from lazyllm.tools.rag.doc_node import DocNode
from typing import Iterable

class CustomReader(LazyLLMReaderBase):
    def _lazy_load_data(self, file_paths: list, **kwargs) -> Iterable[DocNode]:
        for file_path in file_paths:
            # Process each file and yield DocNode
            content = self._read_file(file_path)
            yield DocNode(
                text=content,
                metadata={"source": file_path}
            )

# Create reader instance
reader = CustomReader(return_trace=True)

# Load documents
documents = reader.forward(file_paths=["doc1.txt", "doc2.txt"])
```
Source code in lazyllm/tools/rag/readers/readerBase.py
class LazyLLMReaderBase(ModuleBase, metaclass=LazyLLMRegisterMetaClass):
    """
基础文档读取器类,提供了文档加载的基本接口。继承自ModuleBase,使用LazyLLMRegisterMetaClass作为元类。

Args:
    return_trace (bool): 是否返回处理过程的追踪信息。默认为True。

**说明:**
- 提供了惰性加载和普通加载两种方式
- 子类需要实现_lazy_load_data方法
- 支持批量处理文档
- 自动转换为标准化的DocNode格式


Examples:

    ```python
    from lazyllm.tools.rag.readers.readerBase import LazyLLMReaderBase
    from lazyllm.tools.rag.doc_node import DocNode
    from typing import Iterable

    class CustomReader(LazyLLMReaderBase):
        def _lazy_load_data(self, file_paths: list, **kwargs) -> Iterable[DocNode]:
            for file_path in file_paths:
                # Process each file and yield DocNode
                content = self._read_file(file_path)
                yield DocNode(
                    text=content,
                    metadata={"source": file_path}
                )

    # Create reader instance
    reader = CustomReader(return_trace=True)

    # Load documents
    documents = reader.forward(file_paths=["doc1.txt", "doc2.txt"])
    ```
    """
    def __init__(self, *args, return_trace: bool = True, **kwargs):
        super().__init__(return_trace=return_trace)

    def _lazy_load_data(self, *args, **load_kwargs) -> Iterable[DocNode]:
        raise NotImplementedError(f"{self.__class__.__name__} does not implement lazy_load_data method.")

    def _load_data(self, *args, **load_kwargs) -> List[DocNode]:
        return list(self._lazy_load_data(*args, **load_kwargs))

    def forward(self, *args, **kwargs) -> List[DocNode]:
        return self._load_data(*args, **kwargs)

lazyllm.tools.rag.readers.PandasExcelReader

Bases: LazyLLMReaderBase

用于读取 Excel 文件(.xlsx),并将内容提取为文本。

Parameters:

  • concat_rows (bool, default: True ) –

    是否将所有行拼接为一个文本块。

  • sheet_name (Optional[str], default: None ) –

    要读取的工作表名称。若为 None,则读取所有工作表。

  • pandas_config (Optional[Dict], default: None ) –

    pandas.read_excel 的可选配置项。

  • return_trace (bool, default: True ) –

    是否返回处理过程的 trace。

Source code in lazyllm/tools/rag/readers/pandasReader.py
class PandasExcelReader(LazyLLMReaderBase):
    """用于读取 Excel 文件(.xlsx),并将内容提取为文本。

Args:
    concat_rows (bool): 是否将所有行拼接为一个文本块。
    sheet_name (Optional[str]): 要读取的工作表名称。若为 None,则读取所有工作表。
    pandas_config (Optional[Dict]): pandas.read_excel 的可选配置项。
    return_trace (bool): 是否返回处理过程的 trace。
"""
    def __init__(self, concat_rows: bool = True, sheet_name: Optional[str] = None,
                 pandas_config: Optional[Dict] = None, return_trace: bool = True) -> None:
        super().__init__(return_trace=return_trace)
        self._concat_rows = concat_rows
        self._sheet_name = sheet_name
        self._pandas_config = pandas_config or {}

    def _load_data(self, file: Path, fs: Optional[AbstractFileSystem] = None) -> List[DocNode]:
        openpyxl_spec = importlib.util.find_spec("openpyxl")
        if openpyxl_spec is not None: pass
        else: raise ImportError("Please install openpyxl to read Excel files. "
                                "You can install it with `pip install openpyxl`")

        if not isinstance(file, Path): file = Path(file)
        if fs:
            with fs.open(file) as f:
                dfs = pd.read_excel(f, self._sheet_name, **self._pandas_config)
        else:
            dfs = pd.read_excel(file, self._sheet_name, **self._pandas_config)

        documents = []
        if isinstance(dfs, pd.DataFrame):
            df = dfs.fillna("")
            text_list = (df.astype(str).apply(lambda row: " ".join(row.values), axis=1).tolist())

            if self._concat_rows: documents.append(DocNode(text="\n".join(text_list)))
            else: documents.extend([DocNode(text=text) for text in text_list])
        else:
            for df in dfs.values():
                df = df.fillna("")
                text_list = (df.astype(str).apply(lambda row: " ".join(row), axis=1).tolist())

                if self._concat_rows: documents.append(DocNode(text="\n".join(text_list)))
                else: documents.extend([DocNode(text=text) for text in text_list])

        return documents

lazyllm.tools.rag.readers.PDFReader

Bases: LazyLLMReaderBase

用于读取 PDF 文件并提取其中的文本内容。

Parameters:

  • return_full_document (bool, default: False ) –

    是否将整份 PDF 合并为一个文档节点。若为 False,则每页作为一个节点。

  • return_trace (bool, default: True ) –

    是否返回处理过程的 trace,默认为 True。

Source code in lazyllm/tools/rag/readers/pdfReader.py
class PDFReader(LazyLLMReaderBase):
    """用于读取 PDF 文件并提取其中的文本内容。

Args:
    return_full_document (bool): 是否将整份 PDF 合并为一个文档节点。若为 False,则每页作为一个节点。
    return_trace (bool): 是否返回处理过程的 trace,默认为 True。
"""
    def __init__(self, return_full_document: bool = False, return_trace: bool = True) -> None:
        super().__init__(return_trace=return_trace)
        self._return_full_document = return_full_document

    @retry(stop=stop_after_attempt(RETRY_TIMES))
    def _load_data(self, file: Path, fs: Optional[AbstractFileSystem] = None) -> List[DocNode]:
        if not isinstance(file, Path): file = Path(file)

        fs = fs or get_default_fs()
        with fs.open(file, 'rb') as fp:
            stream = fp if is_default_fs(fs) else io.BytesIO(fp.read())
            pdf = pypdf.PdfReader(stream)
            num_pages = len(pdf.pages)
            docs = []
            if self._return_full_document:
                text = "\n".join(pdf.pages[page].extract_text() for page in range(num_pages))
                docs.append(DocNode(text=text))
            else:
                for page in range(num_pages):
                    page_text = pdf.pages[page].extract_text()
                    page_label = pdf.page_labels[page]
                    metadata = {"page_label": page_label}
                    docs.append(DocNode(text=page_text, metadata=metadata))
            return docs

lazyllm.tools.rag.readers.PPTXReader

Bases: LazyLLMReaderBase

用于解析 PPTX(PowerPoint)文件的读取器,能够提取幻灯片中的文本,并对嵌入图像进行视觉描述生成。

Parameters:

  • return_trace (bool, default: True ) –

    是否记录处理过程的 trace,默认为 True。

Source code in lazyllm/tools/rag/readers/pptxReader.py
class PPTXReader(LazyLLMReaderBase):
    """用于解析 PPTX(PowerPoint)文件的读取器,能够提取幻灯片中的文本,并对嵌入图像进行视觉描述生成。

Args:
    return_trace (bool): 是否记录处理过程的 trace,默认为 True。
"""
    def __init__(self, return_trace: bool = True) -> None:
        try:
            thirdparty.check_packages(['python-pptx', 'torch', 'Pillow', 'transformers'])
        except ImportError:
            raise ImportError("Please install extra dependencies that are required for the "
                              "PPTXReader: `pip install torch transformers python-pptx Pillow`")

        super().__init__(return_trace=return_trace)
        model = tf.VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
        feature_extractor = tf.ViTFeatureExtractor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
        tokenizer = tf.AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

        self._parser_config = {"feature_extractor": feature_extractor, "model": model, "tokenizer": tokenizer}

    def _caption_image(self, tmp_image_file: str) -> str:
        from PIL import Image

        model = self._parser_config['model']
        feature_extractor = self._parser_config['feature_extractor']
        tokenizer = self._parser_config['tokenizer']

        device = infer_torch_device()
        model.to(device)

        max_length = 16
        num_beams = 4
        gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

        i_image = Image.open(tmp_image_file)
        if i_image.mode != "RGB": i_image = i_image.convert(mode="RGB")

        pixel_values = feature_extractor(images=[i_image], return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(device)

        output_ids = model.generate(pixel_values, **gen_kwargs)

        preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        return preds[0].strip()

    def _load_data(self, file: Path, fs: Optional[AbstractFileSystem] = None) -> List[DocNode]:
        if not isinstance(file, Path): file = Path(file)

        if fs:
            with fs.open(file) as f:
                presentation = pptx.Presentation(f)
        else:
            presentation = pptx.Presentation(file)

        result = ""
        for i, slide in enumerate(presentation.slides):
            result += f"\n\nSlide #{i}: \n"
            for shape in slide.shapes:
                if hasattr(shape, "image"):
                    image = shape.image
                    image_bytes = image.blob
                    f = tempfile.NamedTemporaryFile("wb", delete=False)
                    try:
                        f.write(image_bytes)
                        f.close()
                        result += f"\n Image: {self._caption_image(f.name)}\n\n"
                    finally:
                        os.unlink(f.name)

                if hasattr(shape, "text"): result += f"{shape.text}\n"
        return [DocNode(text=result)]

lazyllm.tools.rag.readers.VideoAudioReader

Bases: LazyLLMReaderBase

用于从视频或音频文件中提取语音内容的读取器,依赖 OpenAI 的 Whisper 模型进行语音识别。

Parameters:

  • model_version (str, default: 'base' ) –

    Whisper 模型的版本(如 "base", "small", "medium", "large"),默认为 "base"。

  • return_trace (bool, default: True ) –

    是否返回处理过程的 trace,默认为 True。

Source code in lazyllm/tools/rag/readers/videoAudioReader.py
class VideoAudioReader(LazyLLMReaderBase):
    """用于从视频或音频文件中提取语音内容的读取器,依赖 OpenAI 的 Whisper 模型进行语音识别。

Args:
    model_version (str): Whisper 模型的版本(如 "base", "small", "medium", "large"),默认为 "base"。
    return_trace (bool): 是否返回处理过程的 trace,默认为 True。
"""
    def __init__(self, model_version: str = "base", return_trace: bool = True) -> None:
        super().__init__(return_trace=return_trace)
        self._model_version = model_version

        try:
            import whisper
        except ImportError:
            raise ImportError("Please install OpenAI whisper model "
                              "`pip install openai-whisper` to use the model")

        model = whisper.load_model(self._model_version)
        self._parser_config = {"model": model}

    def _load_data(self, file: Path, fs: Optional[AbstractFileSystem] = None) -> List[DocNode]:
        import whisper

        if not isinstance(file, Path): file = Path(file)

        if file.name.endswith("mp4"):
            try:
                from pydub import AudioSegment
            except ImportError:
                raise ImportError("Please install pydub `pip install pydub`")

            if fs:
                with fs.open(file, 'rb') as f:
                    video = AudioSegment.from_file(f, format="mp4")
            else:
                video = AudioSegment.from_file(file, format="mp4")

            audio = video.split_to_mono()[0]
            file_str = str(file)[:-4] + ".mp3"
            audio.export(file_str, format="mp3")

        model = cast(whisper.Whisper, self._parser_config["model"])
        result = model.transcribe(str(file))

        transcript = result['text']
        return [DocNode(text=transcript)]

lazyllm.tools.SqlManager

Bases: DBManager

SqlManager是与数据库进行交互的专用工具。它提供了连接数据库,设置、创建、检查数据表,插入数据,执行查询的方法。

Parameters:

  • db_type (str) –

    "PostgreSQL","SQLite", "MySQL", "MSSQL"。注意当类型为"SQLite"时,db_name为文件路径或者":memory:"

  • user (str) –

    用户名

  • password (str) –

    密码

  • host (str) –

    主机名或IP

  • port (int) –

    端口号

  • db_name (str) –

    数据仓库名

  • **options_str (str, default: None ) –

    k1=v1&k2=v2形式表示的选项设置

Source code in lazyllm/tools/sql/sql_manager.py
class SqlManager(DBManager):
    """SqlManager是与数据库进行交互的专用工具。它提供了连接数据库,设置、创建、检查数据表,插入数据,执行查询的方法。

Arguments:
    db_type (str): "PostgreSQL","SQLite", "MySQL", "MSSQL"。注意当类型为"SQLite"时,db_name为文件路径或者":memory:"
    user (str): 用户名
    password (str): 密码
    host (str): 主机名或IP
    port (int): 端口号
    db_name (str): 数据仓库名
    **options_str (str): k1=v1&k2=v2形式表示的选项设置
"""
    DB_TYPE_SUPPORTED = set(["postgresql", "mysql", "mssql", "sqlite", "mysql+pymysql"])
    DB_DRIVER_MAP = {"mysql": "pymysql"}
    PYTYPE_TO_SQL_MAP = {
        "integer": sqlalchemy.Integer,
        "string": sqlalchemy.Text,
        "text": sqlalchemy.Text,
        "boolean": sqlalchemy.Boolean,
        "float": sqlalchemy.Float,
        "datetime": sqlalchemy.DateTime,
        "bytes": sqlalchemy.LargeBinary,
        "bool": sqlalchemy.Boolean,
        "date": sqlalchemy.Date,
        "time": sqlalchemy.Time,
        "list": sqlalchemy.ARRAY,
        "dict": sqlalchemy.JSON,
        "uuid": sqlalchemy.Uuid,
    }

    def __init__(self, db_type: str, user: str, password: str, host: str, port: int, db_name: str, *,
                 options_str: str = None, tables_info_dict: Dict = None):
        db_type = db_type.lower()
        if db_type not in self.DB_TYPE_SUPPORTED:
            raise ValueError(f"{db_type} not supported")
        super().__init__(db_type)
        self._user = user
        self._password = password
        self._host = host
        self._port = port
        self._db_name = db_name
        self._tables_desc_dict = {}
        self._engine = None
        self._visible_tables = None
        self._metadata = sqlalchemy.MetaData()
        self._options_str = options_str
        if tables_info_dict:
            self._init_tables_by_info(tables_info_dict)

    def _init_tables_by_info(self, tables_info_dict):
        try:
            tables_info = TablesInfo.model_validate(tables_info_dict)
            self._visible_tables = [table_info.name for table_info in tables_info.tables]
            # create table if not exist
            self._create_tables_by_info(tables_info)
            desc_dict = self._gen_desc_by_info(tables_info)
            self.set_desc(desc_dict)
        except pydantic.ValidationError as e:
            raise ValueError(f"Validate tables_info_dict failed: {str(e)}")

    def _create_tables_by_info(self, tables_info: TablesInfo):
        for table_info in tables_info.tables:
            attrs = {"__tablename__": table_info.name, "__table_args__": {"extend_existing": True},
                     "metadata": self._metadata}
            for column_info in table_info.columns:
                column_type = column_info.data_type.lower()
                is_nullable = column_info.nullable
                column_name = column_info.name
                is_primary = column_info.is_primary_key
                # Use text for unsupported column type
                real_type = self.PYTYPE_TO_SQL_MAP.get(column_type, sqlalchemy.Text)
                attrs[column_name] = sqlalchemy.Column(real_type, nullable=is_nullable, primary_key=is_primary)
            # When create dynamic class with same name, old version will be replaced
            TableClass = type(table_info.name.capitalize(), (TableBase,), attrs)
            self.create_table(TableClass)

    def _gen_desc_by_info(self, tables_info: TablesInfo) -> dict:
        desc_dict = {}
        for table_info in tables_info.tables:
            table_comment = ""
            if table_info.comment:
                table_comment += f"COMMENT ON TABLE '{table_info.name}': {table_info.comment}\n"
            for column_info in table_info.columns:
                table_comment += f"COMMENT ON COLUMN '{table_info.name}.{column_info.name}': {column_info.comment}\n"
            if table_comment:
                desc_dict[table_info.name] = table_comment
        return desc_dict

    def _gen_conn_url(self) -> str:
        if self._db_type == "sqlite":
            conn_url = f"sqlite:///{self._db_name}{('?' + self._options_str) if self._options_str else ''}"
        else:
            driver = self.DB_DRIVER_MAP.get(self._db_type, "")
            password = quote_plus(self._password)
            conn_url = (f"{self._db_type}{('+' + driver) if driver else ''}://{self._user}:{password}@{self._host}"
                        f":{self._port}/{self._db_name}{('?' + self._options_str) if self._options_str else ''}")
        return conn_url

    @property
    def engine(self):
        if self._engine is None:
            self._engine = sqlalchemy.create_engine(self._gen_conn_url())
        return self._engine

    @contextmanager
    def get_session(self):
        """这是一个上下文管理器,它创建并返回一个数据库连接Session,并在完成时自动提交或回滚更改并在使用完成后自动关闭会话。

**Returns:**

- sqlalchemy.orm.Session: sqlalchemy 数据库会话
"""
        _Session = sessionmaker(bind=self.engine)
        session = _Session()
        try:
            yield session
            session.commit()
        except Exception as e:
            session.rollback()
            raise e
        finally:
            session.close()

    def check_connection(self) -> DBResult:
        """检查当前SqlManager的连接状态。

**Returns:**

- DBResult: DBResult.status 连接成功(True), 连接失败(False)。DBResult.detail 包含失败信息
"""
        try:
            with self.engine.connect() as _:
                return DBResult()
        except SQLAlchemyError as e:
            return DBResult(status=DBStatus.FAIL, detail=str(e))

    @property
    def desc(self) -> str:
        if self._desc is None:
            self.set_desc(tables_desc_dict={})
        return self._desc

    def set_desc(self, tables_desc_dict: dict = {}):  # noqa B006
        """对于SqlManager搭配LLM使用自然语言查询的表项设置其描述,尤其当其表名、列名及取值不具有自解释能力时。
例如:
数据表Document的status列取值包括: "waiting", "working", "success", "failed",tables_desc_dict参数应为 {"Document": "status列取值包括: waiting, working, success, failed"}

Args:
    tables_desc_dict (dict): 表项的补充说明
"""
        self._desc = ""
        if not isinstance(tables_desc_dict, dict):
            raise ValueError(f"desc type {type(tables_desc_dict)} not supported")
        self._tables_desc_dict = tables_desc_dict
        if len(self.visible_tables) == 0:
            return
        # Generate desc according to table schema and comment
        self._desc = "The tables description is as follows\n```\n"
        for table_name in self.visible_tables:
            self._desc += f"Table {table_name}\n(\n"
            TableCls = self.get_table_orm_class(table_name)
            if TableCls is None:
                # The table could be dropped in other session
                continue
            table_columns = TableCls.__table__.columns
            for i, column in enumerate(table_columns):
                self._desc += f" {column.name} {column.type}"
                if i != len(table_columns) - 1:
                    self._desc += ","
                self._desc += "\n"
            self._desc += ");\n"
            if table_name in tables_desc_dict:
                self._desc += tables_desc_dict[table_name] + "\n\n"
        self._desc += "```\n"

    @property
    def visible_tables(self):
        if self._visible_tables is None:
            self._visible_tables = self.get_all_tables()
        return self._visible_tables

    @visible_tables.setter
    def visible_tables(self, visible_tables: list):
        all_tables = set(self.get_all_tables())
        for ele in visible_tables:
            if ele not in all_tables:
                raise ValueError(f"Table {ele} not found in database")
        self._visible_tables = visible_tables
        self.set_desc(self._tables_desc_dict)

    def _refresh_metadata(self, only=None):
        # refresh metadata in case of deleting/creating table in other session
        self._metadata.clear()
        self._metadata.reflect(bind=self.engine, only=only)

    def get_all_tables(self) -> list:
        """返回当前数据库中的所有表名。
"""
        self._refresh_metadata()
        return list(self._metadata.tables.keys())

    def get_table_orm_class(self, table_name):
        """返回数据表名对应的sqlalchemy orm类。结合get_session,进行orm操作
"""
        self._refresh_metadata(only=[table_name])
        Base = automap_base(metadata=self._metadata)
        Base.prepare()
        return getattr(Base.classes, table_name, None)

    def execute_commit(self, statement: str):
        """执行无返回的sql脚本并提交更改。
"""
        with self.get_session() as session:
            session.execute(sqlalchemy.text(statement))

    def execute_query(self, statement: str) -> str:
        """执行sql查询脚本并以JSON字符串返回结果。
"""
        statement = re.sub(r"/\*.*?\*/", "", statement, flags=re.DOTALL).strip()
        create_table_pattern = r".*\s*create\s+table\s+.*"
        drop_table_pattern = r".*\s*drop\s+table\s+.*"
        statement_lower = statement.lower()
        if re.match(create_table_pattern, statement_lower):
            return f"Create table not supported. Original statement: {statement}"
        elif re.match(drop_table_pattern, statement_lower):
            return f"Drop table not supported. Original statement: {statement}"
        try:
            result = []
            _Session = sessionmaker(bind=self.engine)
            # Use original session without post commit
            with _Session() as session:
                cursor_result = session.execute(sqlalchemy.text(statement))
                columns = list(cursor_result.keys())
                result = [dict(zip(columns, row)) for row in cursor_result]
            str_result = json.dumps(result, ensure_ascii=False, default=self._serialize_uncommon_type)
        except Exception as e:
            str_result = f"Execute SQL ERROR: {str(e)}"
        return str_result

    def _create_by_script(self, table: str) -> DBResult:
        status = DBStatus.SUCCESS
        detail = "Success"
        try:
            with self.engine.connect() as conn:
                conn.execute(sqlalchemy.text(table))
                conn.commit()
        except OperationalError as e:
            status = DBStatus.FAIL
            detail = f"ERROR: {str(e)}"
        return DBResult(status=status, detail=detail)

    def _create_by_api(self, table: Union[DeclarativeBase, DeclarativeMeta]) -> DBResult:
        table.metadata.create_all(bind=self.engine, checkfirst=True)
        return DBResult()

    def create_table(self, table: Union[str, Type[DeclarativeBase], DeclarativeMeta]) -> DBResult:
        """创建数据表

Args:
    table (str/Type[DeclarativeBase]/DeclarativeMeta): 数据表schema。支持三种参数类型:类型为str的sql语句,继承自DeclarativeBase或继承自declarative_base()的ORM类
"""
        status = DBStatus.SUCCESS
        detail = "Success"
        if isinstance(table, str):
            return self._create_by_script(table)
        # Support DeclarativeMeta created by declarative_base() which is deprecated since: 2.0
        elif issubclass(table, DeclarativeBase) or isinstance(table, DeclarativeMeta):
            return self._create_by_api(table)
        else:
            status = DBStatus.FAIL
            detail += f"Failed: Unsupported Type: {table}"
        return DBResult(status=status, detail=detail)

    def drop_table(self, table: Union[str, Type[DeclarativeBase], DeclarativeMeta]) -> DBResult:
        """删除数据表

Args:
    table (str/Type[DeclarativeBase]/DeclarativeMeta): 数据表schema。支持三种参数类型:类型为str的数据表名,继承自DeclarativeBase或继承自declarative_base()的ORM类
"""
        metadata = self._metadata
        if isinstance(table, str):
            tablename = table
        elif issubclass(table, DeclarativeBase) or isinstance(table, DeclarativeMeta):
            tablename = table.__tablename__
        else:
            return DBResult(status=DBStatus.FAIL, detail=f"{table} type unsupported")
        Table = sqlalchemy.Table(tablename, metadata, autoload_with=self.engine)
        Table.drop(self.engine, checkfirst=True)
        return DBResult()

    def insert_values(self, table_name: str, vals: List[dict]) -> DBResult:
        """批量数据插入

Args:
    table_name (str): 数据表名
    vals (List[dict]): 待插入数据,格式为[{"col_name1": v01, "col_name2": v02, ...}, {"col_name1": v11, "col_name2": v12, ...}, ...]
"""
        # Refresh metadata in case of tables created by other api
        TableCls = self.get_table_orm_class(table_name)
        if TableCls is None:
            return DBResult(status=DBStatus.FAIL, detail=f"{table_name} not found in database")
        with self.get_session() as session:
            session.bulk_insert_mappings(TableCls, vals)
        return DBResult()

check_connection()

检查当前SqlManager的连接状态。

Returns:

  • DBResult: DBResult.status 连接成功(True), 连接失败(False)。DBResult.detail 包含失败信息
Source code in lazyllm/tools/sql/sql_manager.py
    def check_connection(self) -> DBResult:
        """检查当前SqlManager的连接状态。

**Returns:**

- DBResult: DBResult.status 连接成功(True), 连接失败(False)。DBResult.detail 包含失败信息
"""
        try:
            with self.engine.connect() as _:
                return DBResult()
        except SQLAlchemyError as e:
            return DBResult(status=DBStatus.FAIL, detail=str(e))

create_table(table)

创建数据表

Parameters:

  • table (str / Type[DeclarativeBase] / DeclarativeMeta) –

    数据表schema。支持三种参数类型:类型为str的sql语句,继承自DeclarativeBase或继承自declarative_base()的ORM类

Source code in lazyllm/tools/sql/sql_manager.py
    def create_table(self, table: Union[str, Type[DeclarativeBase], DeclarativeMeta]) -> DBResult:
        """创建数据表

Args:
    table (str/Type[DeclarativeBase]/DeclarativeMeta): 数据表schema。支持三种参数类型:类型为str的sql语句,继承自DeclarativeBase或继承自declarative_base()的ORM类
"""
        status = DBStatus.SUCCESS
        detail = "Success"
        if isinstance(table, str):
            return self._create_by_script(table)
        # Support DeclarativeMeta created by declarative_base() which is deprecated since: 2.0
        elif issubclass(table, DeclarativeBase) or isinstance(table, DeclarativeMeta):
            return self._create_by_api(table)
        else:
            status = DBStatus.FAIL
            detail += f"Failed: Unsupported Type: {table}"
        return DBResult(status=status, detail=detail)

drop_table(table)

删除数据表

Parameters:

  • table (str / Type[DeclarativeBase] / DeclarativeMeta) –

    数据表schema。支持三种参数类型:类型为str的数据表名,继承自DeclarativeBase或继承自declarative_base()的ORM类

Source code in lazyllm/tools/sql/sql_manager.py
    def drop_table(self, table: Union[str, Type[DeclarativeBase], DeclarativeMeta]) -> DBResult:
        """删除数据表

Args:
    table (str/Type[DeclarativeBase]/DeclarativeMeta): 数据表schema。支持三种参数类型:类型为str的数据表名,继承自DeclarativeBase或继承自declarative_base()的ORM类
"""
        metadata = self._metadata
        if isinstance(table, str):
            tablename = table
        elif issubclass(table, DeclarativeBase) or isinstance(table, DeclarativeMeta):
            tablename = table.__tablename__
        else:
            return DBResult(status=DBStatus.FAIL, detail=f"{table} type unsupported")
        Table = sqlalchemy.Table(tablename, metadata, autoload_with=self.engine)
        Table.drop(self.engine, checkfirst=True)
        return DBResult()

execute_commit(statement)

执行无返回的sql脚本并提交更改。

Source code in lazyllm/tools/sql/sql_manager.py
    def execute_commit(self, statement: str):
        """执行无返回的sql脚本并提交更改。
"""
        with self.get_session() as session:
            session.execute(sqlalchemy.text(statement))

execute_query(statement)

执行sql查询脚本并以JSON字符串返回结果。

Source code in lazyllm/tools/sql/sql_manager.py
    def execute_query(self, statement: str) -> str:
        """执行sql查询脚本并以JSON字符串返回结果。
"""
        statement = re.sub(r"/\*.*?\*/", "", statement, flags=re.DOTALL).strip()
        create_table_pattern = r".*\s*create\s+table\s+.*"
        drop_table_pattern = r".*\s*drop\s+table\s+.*"
        statement_lower = statement.lower()
        if re.match(create_table_pattern, statement_lower):
            return f"Create table not supported. Original statement: {statement}"
        elif re.match(drop_table_pattern, statement_lower):
            return f"Drop table not supported. Original statement: {statement}"
        try:
            result = []
            _Session = sessionmaker(bind=self.engine)
            # Use original session without post commit
            with _Session() as session:
                cursor_result = session.execute(sqlalchemy.text(statement))
                columns = list(cursor_result.keys())
                result = [dict(zip(columns, row)) for row in cursor_result]
            str_result = json.dumps(result, ensure_ascii=False, default=self._serialize_uncommon_type)
        except Exception as e:
            str_result = f"Execute SQL ERROR: {str(e)}"
        return str_result

get_all_tables()

返回当前数据库中的所有表名。

Source code in lazyllm/tools/sql/sql_manager.py
    def get_all_tables(self) -> list:
        """返回当前数据库中的所有表名。
"""
        self._refresh_metadata()
        return list(self._metadata.tables.keys())

get_session()

这是一个上下文管理器,它创建并返回一个数据库连接Session,并在完成时自动提交或回滚更改并在使用完成后自动关闭会话。

Returns:

  • sqlalchemy.orm.Session: sqlalchemy 数据库会话
Source code in lazyllm/tools/sql/sql_manager.py
    @contextmanager
    def get_session(self):
        """这是一个上下文管理器,它创建并返回一个数据库连接Session,并在完成时自动提交或回滚更改并在使用完成后自动关闭会话。

**Returns:**

- sqlalchemy.orm.Session: sqlalchemy 数据库会话
"""
        _Session = sessionmaker(bind=self.engine)
        session = _Session()
        try:
            yield session
            session.commit()
        except Exception as e:
            session.rollback()
            raise e
        finally:
            session.close()

get_table_orm_class(table_name)

返回数据表名对应的sqlalchemy orm类。结合get_session,进行orm操作

Source code in lazyllm/tools/sql/sql_manager.py
    def get_table_orm_class(self, table_name):
        """返回数据表名对应的sqlalchemy orm类。结合get_session,进行orm操作
"""
        self._refresh_metadata(only=[table_name])
        Base = automap_base(metadata=self._metadata)
        Base.prepare()
        return getattr(Base.classes, table_name, None)

insert_values(table_name, vals)

批量数据插入

Parameters:

  • table_name (str) –

    数据表名

  • vals (List[dict]) –

    待插入数据,格式为[{"col_name1": v01, "col_name2": v02, ...}, {"col_name1": v11, "col_name2": v12, ...}, ...]

Source code in lazyllm/tools/sql/sql_manager.py
    def insert_values(self, table_name: str, vals: List[dict]) -> DBResult:
        """批量数据插入

Args:
    table_name (str): 数据表名
    vals (List[dict]): 待插入数据,格式为[{"col_name1": v01, "col_name2": v02, ...}, {"col_name1": v11, "col_name2": v12, ...}, ...]
"""
        # Refresh metadata in case of tables created by other api
        TableCls = self.get_table_orm_class(table_name)
        if TableCls is None:
            return DBResult(status=DBStatus.FAIL, detail=f"{table_name} not found in database")
        with self.get_session() as session:
            session.bulk_insert_mappings(TableCls, vals)
        return DBResult()

set_desc(tables_desc_dict={})

对于SqlManager搭配LLM使用自然语言查询的表项设置其描述,尤其当其表名、列名及取值不具有自解释能力时。 例如: 数据表Document的status列取值包括: "waiting", "working", "success", "failed",tables_desc_dict参数应为 {"Document": "status列取值包括: waiting, working, success, failed"}

Parameters:

  • tables_desc_dict (dict, default: {} ) –

    表项的补充说明

Source code in lazyllm/tools/sql/sql_manager.py
    def set_desc(self, tables_desc_dict: dict = {}):  # noqa B006
        """对于SqlManager搭配LLM使用自然语言查询的表项设置其描述,尤其当其表名、列名及取值不具有自解释能力时。
例如:
数据表Document的status列取值包括: "waiting", "working", "success", "failed",tables_desc_dict参数应为 {"Document": "status列取值包括: waiting, working, success, failed"}

Args:
    tables_desc_dict (dict): 表项的补充说明
"""
        self._desc = ""
        if not isinstance(tables_desc_dict, dict):
            raise ValueError(f"desc type {type(tables_desc_dict)} not supported")
        self._tables_desc_dict = tables_desc_dict
        if len(self.visible_tables) == 0:
            return
        # Generate desc according to table schema and comment
        self._desc = "The tables description is as follows\n```\n"
        for table_name in self.visible_tables:
            self._desc += f"Table {table_name}\n(\n"
            TableCls = self.get_table_orm_class(table_name)
            if TableCls is None:
                # The table could be dropped in other session
                continue
            table_columns = TableCls.__table__.columns
            for i, column in enumerate(table_columns):
                self._desc += f" {column.name} {column.type}"
                if i != len(table_columns) - 1:
                    self._desc += ","
                self._desc += "\n"
            self._desc += ");\n"
            if table_name in tables_desc_dict:
                self._desc += tables_desc_dict[table_name] + "\n\n"
        self._desc += "```\n"

lazyllm.tools.rag.component.bm25.BM25

A BM25 retriever that uses the BM25 algorithm to retrieve nodes.

Source code in lazyllm/tools/rag/component/bm25.py
class BM25:
    """A BM25 retriever that uses the BM25 algorithm to retrieve nodes."""

    def __init__(
        self,
        nodes: List[DocNode],
        language: str = "en",
        topk: int = 2,
        **kwargs,
    ) -> None:
        if language == "en":
            self._stemmer = Stemmer.Stemmer("english")
            self._stopwords = language
            self._tokenizer = lambda t: t
        elif language == "zh":
            self._stemmer = None
            # TODO(ywt): after bm25s supports cn stopwards, update this
            self._stopwords = STOPWORDS_CHINESE
            self._tokenizer = lambda t: " ".join(jieba.lcut(t))
        self.topk = min(topk, len(nodes))
        self.nodes = nodes

        corpus_tokens = bm25s.tokenize(
            [self._tokenizer(node.get_text()) for node in nodes],
            stopwords=self._stopwords,
            stemmer=self._stemmer,
        )
        self.bm25 = bm25s.BM25()
        self.bm25.index(corpus_tokens)

    def retrieve(self, query: str) -> List[Tuple[DocNode, float]]:
        tokenized_query = bm25s.tokenize(
            self._tokenizer(query), stopwords=self._stopwords, stemmer=self._stemmer
        )
        indexs, scores = self.bm25.retrieve(tokenized_query, k=self.topk)
        results = []
        for idx, score in zip(indexs[0], scores[0]):
            results.append((self.nodes[idx], score))
        return results

lazyllm.tools.rag.doc_to_db.DocInfoSchemaItem

Bases: TypedDict

文档信息结构中单个字段的定义。

Parameters:

  • key (str) –

    字段名

  • desc (str) –

    字段含义描述

  • type (str) –

    字段的数据类型

Source code in lazyllm/tools/rag/doc_to_db/doc_analysis.py
class DocInfoSchemaItem(TypedDict):
    """文档信息结构中单个字段的定义。

Args:
    key (str): 字段名
    desc (str): 字段含义描述
    type (str): 字段的数据类型
"""
    key: str
    desc: str
    type: str

lazyllm.tools.rag.doc_to_db.DocGenreAnalyser

用于分析文档所属的类别,例如合同、简历、发票等。通过读取文档内容,并结合大模型判断其类型。

Parameters:

  • maximum_doc_num (int, default: 3 ) –

    最多分析的文档数量,默认是 3。

Examples:

>>> import lazyllm
>>> from lazyllm.components.doc_info_extractor import DocGenreAnalyser
>>> from lazyllm import OnlineChatModule
>>> m = OnlineChatModule(source="openai")
>>> analyser = DocGenreAnalyser()
>>> genre = analyser.analyse_doc_genre(m, "path/to/document.txt")
>>> print(genre)
contract
Source code in lazyllm/tools/rag/doc_to_db/doc_analysis.py
class DocGenreAnalyser:
    """用于分析文档所属的类别,例如合同、简历、发票等。通过读取文档内容,并结合大模型判断其类型。

Args:
    maximum_doc_num (int): 最多分析的文档数量,默认是 3。


Examples:
    >>> import lazyllm
    >>> from lazyllm.components.doc_info_extractor import DocGenreAnalyser
    >>> from lazyllm import OnlineChatModule
    >>> m = OnlineChatModule(source="openai")
    >>> analyser = DocGenreAnalyser()
    >>> genre = analyser.analyse_doc_genre(m, "path/to/document.txt")
    >>> print(genre)
    contract
    """
    ONE_DOC_TOKEN_LIMIT = 10000

    def __init__(self, maximum_doc_num=3):
        self._reader = DirectoryReader(None, {}, {})
        self._pattern = re.compile(r"```json(.+?)```", re.DOTALL)
        self._maximum_doc_num = maximum_doc_num
        self._tiktoken_tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
        assert self._maximum_doc_num > 0

    def gen_detection_query(self, doc_path: str):
        root_nodes = self._reader.load_data([doc_path], None)
        doc_content = ""
        for root_node in root_nodes:
            doc_content += root_node.text + "\n"
        doc_content = trim_content_by_token_num(self._tiktoken_tokenizer, doc_content, self.ONE_DOC_TOKEN_LIMIT)
        query = DOC_KWS_PROMPTS["doc_type_detection"].format(doc_content=doc_content)
        query += "\nBelow is the content of each document sample.\n\n"
        return query

    def _extract_doc_type_from_response(self, str_response: str) -> str:
        # Remove the triple backticks if present
        matches = self._pattern.findall(str_response)
        if matches:
            # Return the first match
            extracted_content = matches[0].strip()
            try:
                res_dict = json.loads(extracted_content)
                if not isinstance(res_dict, dict) or "doc_type" not in res_dict:
                    return ""
                return res_dict["doc_type"]
            except Exception as e:
                lazyllm.LOG.warning(f"Exception: {str(e)}, response_str: {str_response}")
                return ""
        else:
            return ""

    def analyse_doc_genre(self, llm: Union[OnlineChatModule, TrainableModule], doc_path: str) -> str:
        query = self.gen_detection_query(doc_path)
        response = llm(query)
        doc_genre = self._extract_doc_type_from_response(response)
        return doc_genre

lazyllm.tools.rag.doc_to_db.DocInfoSchemaAnalyser

用于从文档中抽取出关键信息字段的结构,如字段名、描述、字段类型。可用于构建信息提取模板。

Parameters:

  • maximum_doc_num (int, default: 3 ) –

    用于生成schema的最大文档数量,默认是 3。

Examples:

>>> from lazyllm.components.doc_info_extractor import DocInfoSchemaAnalyser
>>> from lazyllm import OnlineChatModule
>>> analyser = DocInfoSchemaAnalyser()
>>> m = OnlineChatModule(source="openai")
>>> schema = analyser.analyse_info_schema(m, "contract", ["doc1.txt", "doc2.txt"])
>>> print(schema)
[{'key': 'party_a', 'desc': 'The first party', 'type': 'str'}, ...]
Source code in lazyllm/tools/rag/doc_to_db/doc_analysis.py
class DocInfoSchemaAnalyser:
    """用于从文档中抽取出关键信息字段的结构,如字段名、描述、字段类型。可用于构建信息提取模板。

Args:
    maximum_doc_num (int): 用于生成schema的最大文档数量,默认是 3。


Examples:
    >>> from lazyllm.components.doc_info_extractor import DocInfoSchemaAnalyser
    >>> from lazyllm import OnlineChatModule
    >>> analyser = DocInfoSchemaAnalyser()
    >>> m = OnlineChatModule(source="openai")
    >>> schema = analyser.analyse_info_schema(m, "contract", ["doc1.txt", "doc2.txt"])
    >>> print(schema)
    [{'key': 'party_a', 'desc': 'The first party', 'type': 'str'}, ...]
    """
    ONE_DOC_TOKEN_LIMIT = 30000

    def __init__(self, maximum_doc_num=3):
        self._reader = DirectoryReader(None, {}, {})
        self._pattern = re.compile(r"```json(.+?)```", re.DOTALL)
        self._maximum_doc_num = maximum_doc_num
        self._tiktoken_tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
        assert self._maximum_doc_num > 0

    def _gen_first_round_query(self, doc_type: str, doc_paths: list[str]):
        doc_contents = []
        for doc_path in doc_paths:
            root_nodes = self._reader.load_data([doc_path], None)
            doc_content = ""
            for root_node in root_nodes:
                doc_content += root_node.text + "\n"
            doc_content = trim_content_by_token_num(self._tiktoken_tokenizer, doc_content, self.ONE_DOC_TOKEN_LIMIT)
            doc_contents.append(doc_content)
        query = DOC_KWS_PROMPTS["kws_generation"].format(number=len(doc_contents), doc_type=doc_type)
        query += "\nBelow is the content of each document sample.\n\n"
        for i, doc_content in enumerate(doc_contents):
            query += f"Document {i+1}:\n```\n{doc_content}\n```\n\n"
        return query

    def _extract_schema_from_response(self, str_response: str) -> List[dict]:
        # Remove the triple backticks if present
        matches = self._pattern.findall(str_response)
        empty_list = []
        if matches:
            # Return the first match
            extracted_content = matches[0].strip()
            try:
                kws_list = json.loads(extracted_content)
                # in case of the list is in a dict, unpack it
                if isinstance(kws_list, dict):
                    values = list(kws_list.values())
                    if len(values) == 1 and isinstance(values[0], list):
                        return values[0]
                if not isinstance(kws_list, list):
                    lazyllm.LOG.warning(f"Excepted original type list but got {type(kws_list)} value: {kws_list}")
                    return empty_list
                return kws_list
            except Exception as e:
                lazyllm.LOG.warning(f"Exception: {str(e)}, response_str: {str_response}")
                return empty_list
        else:
            return empty_list

    def analyse_info_schema(
        self, llm: Union[OnlineChatModule, TrainableModule], doc_type: str, doc_paths: list[str]
    ) -> DocInfoSchema:
        """分析文档信息模式的方法,用于从指定类型的文档中提取关键信息字段的结构定义。

Args:
    llm (Union[OnlineChatModule, TrainableModule]): 用于生成信息模式的LLM模型
    doc_type (str): 文档类型,用于指导LLM生成相应的信息模式
    doc_paths (list[str]): 文档路径列表,用于分析的信息来源

**Returns:**

- DocInfoSchema: 包含关键信息字段定义的模式列表,每个字段包含key、desc、type三个属性
"""
        RANDOM_SEED = 1331
        if len(doc_paths) > self._maximum_doc_num:
            doc_paths.sort()
            random.seed(RANDOM_SEED)
            doc_paths = random.sample(doc_paths, self._maximum_doc_num)
        first_round_query = self._gen_first_round_query(doc_type, doc_paths)
        first_response = llm(first_round_query)
        info_schema = self._extract_schema_from_response(first_response)
        for info_schema_item in info_schema:
            is_success, msg = validate_schema_item(info_schema_item, DocInfoSchemaItem)
            if not is_success:
                lazyllm.LOG.warning(f"Please Try Again! Invalid kws dict: {info_schema_item}, error_msg: {msg}")
                return []
        return info_schema

analyse_info_schema(llm, doc_type, doc_paths)

分析文档信息模式的方法,用于从指定类型的文档中提取关键信息字段的结构定义。

Parameters:

  • llm (Union[OnlineChatModule, TrainableModule]) –

    用于生成信息模式的LLM模型

  • doc_type (str) –

    文档类型,用于指导LLM生成相应的信息模式

  • doc_paths (list[str]) –

    文档路径列表,用于分析的信息来源

Returns:

  • DocInfoSchema: 包含关键信息字段定义的模式列表,每个字段包含key、desc、type三个属性
Source code in lazyllm/tools/rag/doc_to_db/doc_analysis.py
    def analyse_info_schema(
        self, llm: Union[OnlineChatModule, TrainableModule], doc_type: str, doc_paths: list[str]
    ) -> DocInfoSchema:
        """分析文档信息模式的方法,用于从指定类型的文档中提取关键信息字段的结构定义。

Args:
    llm (Union[OnlineChatModule, TrainableModule]): 用于生成信息模式的LLM模型
    doc_type (str): 文档类型,用于指导LLM生成相应的信息模式
    doc_paths (list[str]): 文档路径列表,用于分析的信息来源

**Returns:**

- DocInfoSchema: 包含关键信息字段定义的模式列表,每个字段包含key、desc、type三个属性
"""
        RANDOM_SEED = 1331
        if len(doc_paths) > self._maximum_doc_num:
            doc_paths.sort()
            random.seed(RANDOM_SEED)
            doc_paths = random.sample(doc_paths, self._maximum_doc_num)
        first_round_query = self._gen_first_round_query(doc_type, doc_paths)
        first_response = llm(first_round_query)
        info_schema = self._extract_schema_from_response(first_response)
        for info_schema_item in info_schema:
            is_success, msg = validate_schema_item(info_schema_item, DocInfoSchemaItem)
            if not is_success:
                lazyllm.LOG.warning(f"Please Try Again! Invalid kws dict: {info_schema_item}, error_msg: {msg}")
                return []
        return info_schema

lazyllm.tools.rag.doc_to_db.DocInfoExtractor

根据给定的字段结构(schema)从文档中抽取具体的关键信息值,返回格式为 key-value 字典。

Examples:

>>> from lazyllm.components.doc_info_extractor import DocInfoExtractor
>>> from lazyllm import OnlineChatModule
>>> extractor = DocInfoExtractor()
>>> m = OnlineChatModule(source="openai")
>>> schema = [{"key": "party_a", "desc": "Party A name", "type": "str"}]
>>> info = extractor.extract_doc_info(m, "contract.txt", schema)
>>> print(info)
{'party_a': 'ABC Corp'}
Source code in lazyllm/tools/rag/doc_to_db/doc_analysis.py
class DocInfoExtractor:
    """根据给定的字段结构(schema)从文档中抽取具体的关键信息值,返回格式为 key-value 字典。

Args:



Examples:
    >>> from lazyllm.components.doc_info_extractor import DocInfoExtractor
    >>> from lazyllm import OnlineChatModule
    >>> extractor = DocInfoExtractor()
    >>> m = OnlineChatModule(source="openai")
    >>> schema = [{"key": "party_a", "desc": "Party A name", "type": "str"}]
    >>> info = extractor.extract_doc_info(m, "contract.txt", schema)
    >>> print(info)
    {'party_a': 'ABC Corp'}
    """
    ONE_DOC_TOKEN_LIMIT = 50000

    def __init__(self):
        self._reader = DirectoryReader(None, {}, {})
        self._pattern = re.compile(r"```json(.+?)```", re.DOTALL)
        self._tiktoken_tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")

    def _gen_extraction_query(self, doc_path: str, info_schema: DocInfoSchema, extra_desc: str) -> str:
        root_nodes = self._reader.load_data([doc_path], None)
        doc_content = ""
        for root_node in root_nodes:
            doc_content += root_node.text + "\n"
        doc_content = trim_content_by_token_num(self._tiktoken_tokenizer, doc_content, self.ONE_DOC_TOKEN_LIMIT)
        if not extra_desc:
            extra_desc = f"Extra description: \n{extra_desc}"
        query = DOC_KWS_PROMPTS["kws_extraction"].format(
            kws_desc=json.dumps(info_schema), extra_desc=extra_desc, doc_content=doc_content
        )
        return query

    def _extract_kws_value_from_response(self, str_response: str) -> dict:
        # Remove the triple backticks if present
        matches = self._pattern.findall(str_response)
        empty_dict = {}
        if matches:
            # Return the first match
            extracted_content = matches[0].strip()
            try:
                kws_value = json.loads(extracted_content)
                if not isinstance(kws_value, dict):
                    lazyllm.LOG.warning(f"Excepted original type list but got {type(kws_value)}")
                    return empty_dict
                new_dict = {k: v for k, v in kws_value.items() if (isinstance(v, str) and v and v != "None")}
                return new_dict
            except Exception as e:
                lazyllm.LOG.warning(f"Exception: {str(e)}, response_str: {str_response}")
                return empty_dict
        else:
            return empty_dict

    def _format_info_by_schema(self, info: dict, info_schema: DocInfoSchema):
        valid_keys = set([info_schema_item["key"] for info_schema_item in info_schema])
        return {k: v for k, v in info.items() if k in valid_keys}

    def extract_doc_info(
        self,
        llm: Union[OnlineChatModule, TrainableModule],
        doc_path: str,
        info_schema: DocInfoSchema,
        extra_desc: str = "",
    ) -> dict:
        """根据提供的字段结构(schema)从指定文档中抽取具体的关键信息值。

该方法使用大语言模型分析文档内容,根据预定义的字段结构提取相应的信息值,返回格式为 key-value 字典。

Args:
    llm (Union[OnlineChatModule, TrainableModule]): 用于文档信息抽取的大语言模型。
    doc_path (str): 要分析的文档路径。
    info_schema (DocInfoSchema): 字段结构定义,包含需要提取的字段信息。
    extra_desc (str, optional): 额外的描述信息,用于指导信息抽取。默认为空字符串。

Returns:
    dict: 提取出的关键信息字典,键为字段名,值为对应的信息值。
"""
        extraction_query = self._gen_extraction_query(doc_path, info_schema, extra_desc)
        response = llm(extraction_query)
        info: dict = self._extract_kws_value_from_response(response)
        info: dict = self._format_info_by_schema(info, info_schema)
        return info

extract_doc_info(llm, doc_path, info_schema, extra_desc='')

根据提供的字段结构(schema)从指定文档中抽取具体的关键信息值。

该方法使用大语言模型分析文档内容,根据预定义的字段结构提取相应的信息值,返回格式为 key-value 字典。

Parameters:

  • llm (Union[OnlineChatModule, TrainableModule]) –

    用于文档信息抽取的大语言模型。

  • doc_path (str) –

    要分析的文档路径。

  • info_schema (DocInfoSchema) –

    字段结构定义,包含需要提取的字段信息。

  • extra_desc (str, default: '' ) –

    额外的描述信息,用于指导信息抽取。默认为空字符串。

Returns:

  • dict ( dict ) –

    提取出的关键信息字典,键为字段名,值为对应的信息值。

Source code in lazyllm/tools/rag/doc_to_db/doc_analysis.py
    def extract_doc_info(
        self,
        llm: Union[OnlineChatModule, TrainableModule],
        doc_path: str,
        info_schema: DocInfoSchema,
        extra_desc: str = "",
    ) -> dict:
        """根据提供的字段结构(schema)从指定文档中抽取具体的关键信息值。

该方法使用大语言模型分析文档内容,根据预定义的字段结构提取相应的信息值,返回格式为 key-value 字典。

Args:
    llm (Union[OnlineChatModule, TrainableModule]): 用于文档信息抽取的大语言模型。
    doc_path (str): 要分析的文档路径。
    info_schema (DocInfoSchema): 字段结构定义,包含需要提取的字段信息。
    extra_desc (str, optional): 额外的描述信息,用于指导信息抽取。默认为空字符串。

Returns:
    dict: 提取出的关键信息字典,键为字段名,值为对应的信息值。
"""
        extraction_query = self._gen_extraction_query(doc_path, info_schema, extra_desc)
        response = llm(extraction_query)
        info: dict = self._extract_kws_value_from_response(response)
        info: dict = self._format_info_by_schema(info, info_schema)
        return info

lazyllm.tools.rag.doc_to_db.DocToDbProcessor

用于将文档信息抽取并导出到数据库中。

该类通过分析文档主题、抽取字段结构、从文档中提取关键信息,并将其保存至数据库表中。

Parameters:

  • sql_manager (SqlManager) –

    数据库管理模块。

  • doc_table_name (str, default: 'lazyllm_doc_elements' ) –

    存储文档字段的数据库表名,默认为lazyllm_doc_elements

Note
  • 如果表已存在,会自动检测并避免重复创建。
  • 如果你希望重置字段结构,使用 reset_doc_info_schema 方法。
Source code in lazyllm/tools/rag/doc_to_db/doc_processor.py
class DocToDbProcessor:
    """用于将文档信息抽取并导出到数据库中。

该类通过分析文档主题、抽取字段结构、从文档中提取关键信息,并将其保存至数据库表中。

Args:
    sql_manager (SqlManager): 数据库管理模块。
    doc_table_name (str): 存储文档字段的数据库表名,默认为`lazyllm_doc_elements`。

Note:
    - 如果表已存在,会自动检测并避免重复创建。
    - 如果你希望重置字段结构,使用 `reset_doc_info_schema` 方法。
"""

    DB_TYPE_MAP = {
        "int": sqlalchemy.Integer,
        "text": sqlalchemy.Text,
        "float": sqlalchemy.Float,
    }
    UUID_COL_NAME = "lazyllm_uuid"
    CREATED_AT_COL_NAME = "lazyllm_created_at"
    DOC_PATH_COL_NAME = "lazyllm_doc_path"

    def __init__(self, sql_manager: SqlManager, doc_table_name="lazyllm_doc_elements"):
        self._doc_genre_analyser = DocGenreAnalyser()
        self._doc_info_schema_analyser = DocInfoSchemaAnalyser(maximum_doc_num=2)
        self._doc_info_extractor = DocInfoExtractor()
        self._sql_manager = sql_manager
        self._doc_info_schema: DocInfoSchema = None
        self._doc_table_name = doc_table_name
        self._table_class = None
        all_table_names = set(self._sql_manager.get_all_tables()) if sql_manager else {}
        # If doc_table exists, then desc_table must exist as well
        if self._doc_table_name in all_table_names:
            assert (
                LazyllmDocTableDesc.__tablename__ in all_table_names
            ), "LazyllmDocTableDesc table not found in database"
        # Create desc table for totally new database
        if sql_manager and LazyllmDocTableDesc.__tablename__ not in all_table_names:
            self._sql_manager.create_table(LazyllmDocTableDesc)

    @property
    def doc_info_schema(self):
        return self._doc_info_schema

    @property
    def doc_table_name(self):
        return self._doc_table_name

    @doc_table_name.setter
    def doc_table_name(self, doc_table_name: str):
        raise NotImplementedError("Invalid to change table name")

    @property
    def sql_manager(self):
        return self._sql_manager

    @sql_manager.setter
    def sql_manager(self, sql_manager: SqlManager):
        self._sql_manager = sql_manager

    @doc_info_schema.setter
    def doc_info_schema(self, doc_info_schema: DocInfoSchema):
        raise NotImplementedError("As it'a dangerous operation, please use reset_doc_info_schema instead")

    def _save_description_to_db(self, doc_info_schema: DocInfoSchema):
        assert self._sql_manager is not None, "sqlManager is not initialized"
        json_data = json.dumps(doc_info_schema)  # 直接存储为 JSON(如果数据库支持)
        with self._sql_manager.get_session() as session:
            existing = session.query(LazyllmDocTableDesc).filter_by(id=1).first()
            if existing:
                # 更新现有记录
                existing.desc = json_data
            else:
                # 插入新记录
                new_desc = LazyllmDocTableDesc(id=1, desc=json_data)
                session.add(new_desc)
            session.commit()

    def _clear_table_orm(self, drop_doc_table=True):
        if self._table_class is not None:
            if drop_doc_table:
                self._sql_manager.drop_table(self._table_class)
            TableBase.metadata.remove(self._table_class.__table__)
            TableBase.registry._dispose_cls(self._table_class)
            del self._table_class
            self._table_class = None

    def clear(self):
        self._clear_table_orm()
        self._table_class = None
        self._doc_info_schema = None

    # Alert, reset_doc_info_schema will drop old result in db
    def _reset_doc_info_schema(self, doc_info_schema: DocInfoSchema, recreate_doc_table=True):
        assert isinstance(doc_info_schema, list)
        self._save_description_to_db(doc_info_schema)
        self._clear_table_orm(drop_doc_table=recreate_doc_table)
        for schema_item in doc_info_schema:
            is_success, err_msg = validate_schema_item(schema_item, DocInfoSchemaItem)
            assert is_success, err_msg
        self._doc_info_schema = doc_info_schema
        attrs = {"__tablename__": self._doc_table_name, "__table_args__": {"extend_existing": True}}
        # use uuid as primary key
        attrs[self.UUID_COL_NAME] = sqlalchemy.Column(sqlalchemy.String(36), primary_key=True)
        attrs[self.CREATED_AT_COL_NAME] = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False)
        attrs[self.DOC_PATH_COL_NAME] = sqlalchemy.Column(
            sqlalchemy.Text, nullable=False, primary_key=False, index=True
        )
        for schema_item in doc_info_schema:
            real_type = self.DB_TYPE_MAP.get(schema_item["type"].lower(), sqlalchemy.Text)
            attrs[schema_item["key"]] = sqlalchemy.Column(real_type, nullable=True, primary_key=False)
        self._table_class = type(self._doc_table_name.capitalize(), (TableBase,), attrs)
        if recreate_doc_table:
            # After drop_table, create table in db
            db_result = self._sql_manager.create_table(self._table_class)
            if db_result.status != DBStatus.SUCCESS:
                lazyllm.LOG.warning(f"Create table failed: {db_result.detail}")
                self.clear()

    def analyze_info_schema_by_llm(
        self, llm: Union[OnlineChatModule, TrainableModule], doc_paths: List[str], doc_topic: str = ""
    ) -> DocInfoSchema:
        """使用大语言模型从文档节点中推断数据库信息结构。

Args:
    nodes (list[DocNode]): 文档节点列表。

Returns:
    dict: 结构化信息模式,包含表名、字段、关系等信息。
"""
        assert len(doc_paths) > 0, "doc_paths should not be empty"
        if not doc_topic:
            doc_topic = self._doc_genre_analyser.analyse_doc_genre(llm, doc_paths[0])
            if doc_topic == "":
                raise ValueError("Failed to detect doc type")
        return self._doc_info_schema_analyser.analyse_info_schema(llm, doc_topic, doc_paths)

    def extract_info_from_docs(
        self, llm: Union[OnlineChatModule, TrainableModule], doc_paths: List[str], extra_desc: str = ""
    ) -> List[dict]:
        """从文档中提取结构化数据库信息。

该函数使用嵌入和检索技术,在提供的文档中获取数据库相关的文本片段,用于后续模式生成。

Args:
    docs (list[DocNode]): 输入文档列表。
    num_nodes (int): 要提取的片段数量,默认为10。

Returns:
    list[DocNode]: 提取出的相关文档片段。
"""
        existent_doc_paths = self._list_existent_doc_paths_in_db(doc_paths)
        # skip docs already in db
        doc_paths = list(set(doc_paths) - set(existent_doc_paths))
        info_dicts = []
        for doc_path in doc_paths:
            kws_value = self._doc_info_extractor.extract_doc_info(llm, doc_path, self._doc_info_schema, extra_desc)
            if kws_value:
                kws_value[self.DOC_PATH_COL_NAME] = str(doc_path)
                info_dicts.append(kws_value)
            else:
                lazyllm.LOG.warning(f"Extract kws value failed for {doc_path}")
        return info_dicts

    def export_info_to_db(self, info_dicts: List[dict]):
        # Generate uuid explicitly because SQLite doesn't support auto gen uuid
        new_values = []
        for kws_value in info_dicts:
            if kws_value:
                kws_value[self.UUID_COL_NAME] = str(uuid.uuid4())
                kws_value[self.CREATED_AT_COL_NAME] = datetime.now()
                new_values.append(kws_value)
        db_result = self._sql_manager.insert_values(self._doc_table_name, new_values)
        if db_result.status != DBStatus.SUCCESS:
            raise ValueError(f"Insert values failed: {db_result.detail}")

    def _list_existent_doc_paths_in_db(self, doc_paths: list[str]) -> List[str]:
        doc_paths = [str(ele) for ele in doc_paths]
        with self._sql_manager.get_session() as session:
            stmt = sqlalchemy.select(getattr(self._table_class, self.DOC_PATH_COL_NAME)).where(
                getattr(self._table_class, self.DOC_PATH_COL_NAME).in_(doc_paths)
            )
            result = session.execute(stmt).fetchall()
            return [ele[0] for ele in result]

analyze_info_schema_by_llm(llm, doc_paths, doc_topic='')

使用大语言模型从文档节点中推断数据库信息结构。

Parameters:

  • nodes (list[DocNode]) –

    文档节点列表。

Returns:

  • dict ( DocInfoSchema ) –

    结构化信息模式,包含表名、字段、关系等信息。

Source code in lazyllm/tools/rag/doc_to_db/doc_processor.py
    def analyze_info_schema_by_llm(
        self, llm: Union[OnlineChatModule, TrainableModule], doc_paths: List[str], doc_topic: str = ""
    ) -> DocInfoSchema:
        """使用大语言模型从文档节点中推断数据库信息结构。

Args:
    nodes (list[DocNode]): 文档节点列表。

Returns:
    dict: 结构化信息模式,包含表名、字段、关系等信息。
"""
        assert len(doc_paths) > 0, "doc_paths should not be empty"
        if not doc_topic:
            doc_topic = self._doc_genre_analyser.analyse_doc_genre(llm, doc_paths[0])
            if doc_topic == "":
                raise ValueError("Failed to detect doc type")
        return self._doc_info_schema_analyser.analyse_info_schema(llm, doc_topic, doc_paths)

extract_info_from_docs(llm, doc_paths, extra_desc='')

从文档中提取结构化数据库信息。

该函数使用嵌入和检索技术,在提供的文档中获取数据库相关的文本片段,用于后续模式生成。

Parameters:

  • docs (list[DocNode]) –

    输入文档列表。

  • num_nodes (int) –

    要提取的片段数量,默认为10。

Returns:

  • List[dict]

    list[DocNode]: 提取出的相关文档片段。

Source code in lazyllm/tools/rag/doc_to_db/doc_processor.py
    def extract_info_from_docs(
        self, llm: Union[OnlineChatModule, TrainableModule], doc_paths: List[str], extra_desc: str = ""
    ) -> List[dict]:
        """从文档中提取结构化数据库信息。

该函数使用嵌入和检索技术,在提供的文档中获取数据库相关的文本片段,用于后续模式生成。

Args:
    docs (list[DocNode]): 输入文档列表。
    num_nodes (int): 要提取的片段数量,默认为10。

Returns:
    list[DocNode]: 提取出的相关文档片段。
"""
        existent_doc_paths = self._list_existent_doc_paths_in_db(doc_paths)
        # skip docs already in db
        doc_paths = list(set(doc_paths) - set(existent_doc_paths))
        info_dicts = []
        for doc_path in doc_paths:
            kws_value = self._doc_info_extractor.extract_doc_info(llm, doc_path, self._doc_info_schema, extra_desc)
            if kws_value:
                kws_value[self.DOC_PATH_COL_NAME] = str(doc_path)
                info_dicts.append(kws_value)
            else:
                lazyllm.LOG.warning(f"Extract kws value failed for {doc_path}")
        return info_dicts

lazyllm.tools.rag.doc_to_db.extract_db_schema_from_files(file_paths, llm)

给定文档路径和LLM模型,提取文档结构信息。

Parameters:

Returns:

  • DocInfoSchema ( DocInfoSchema ) –

    提取出的字段结构描述。

Examples:

>>> import lazyllm
>>> from lazyllm.components.document_to_db import extract_db_schema_from_files
>>> llm = lazyllm.OnlineChatModule()
>>> file_paths = ["doc1.pdf", "doc2.pdf"]
>>> schema = extract_db_schema_from_files(file_paths, llm)
>>> print(schema)
Source code in lazyllm/tools/rag/doc_to_db/doc_processor.py
def extract_db_schema_from_files(file_paths: List[str], llm: Union[OnlineChatModule, TrainableModule]) -> DocInfoSchema:
    """给定文档路径和LLM模型,提取文档结构信息。

Args:
    file_paths (List[str]): 要分析的文档路径。
    llm (Union[OnlineChatModule, TrainableModule]): 支持聊天的模型模块。

Returns:
    DocInfoSchema: 提取出的字段结构描述。


Examples:
    >>> import lazyllm
    >>> from lazyllm.components.document_to_db import extract_db_schema_from_files
    >>> llm = lazyllm.OnlineChatModule()
    >>> file_paths = ["doc1.pdf", "doc2.pdf"]
    >>> schema = extract_db_schema_from_files(file_paths, llm)
    >>> print(schema)
    """
    return DocToDbProcessor(sql_manager=None).analyze_info_schema_by_llm(llm, file_paths)

lazyllm.tools.rag.readers.DocxReader

Bases: LazyLLMReaderBase

docx格式文件解析器,从 .docx 文件中读取文本内容并封装为文档节点(DocNode)列表。

Parameters:

  • file (Path) –

    .docx 文件路径。

  • fs (Optional[AbstractFileSystem]) –

    可选的文件系统对象,支持自定义读取方式。

Returns:

  • List[DocNode]: 包含文档中所有文本内容的节点列表。

Source code in lazyllm/tools/rag/readers/docxReader.py
class DocxReader(LazyLLMReaderBase):
    """docx格式文件解析器,从 `.docx` 文件中读取文本内容并封装为文档节点(DocNode)列表。

Args:
    file (Path): `.docx` 文件路径。
    fs (Optional[AbstractFileSystem]): 可选的文件系统对象,支持自定义读取方式。

Returns:
    List[DocNode]: 包含文档中所有文本内容的节点列表。
"""
    def _load_data(self, file: Path, fs: Optional[AbstractFileSystem] = None) -> List[DocNode]:
        if not isinstance(file, Path): file = Path(file)

        if fs:
            with fs.open(file) as f:
                text = docx2txt.process(f)
        else:
            text = docx2txt.process(file)

        return [DocNode(text=text)]

lazyllm.tools.rag.readers.EpubReader

Bases: LazyLLMReaderBase

用于读取 .epub 格式电子书的文件读取器。

继承自 LazyLLMReaderBase,只需实现 _load_data 方法,即可通过 Document 组件自动加载 .epub 文件中的内容。

注意:当前版本不支持通过 fsspec 文件系统(如远程路径)加载 epub 文件,若提供 fs 参数,将回退到本地文件读取。

Returns:

  • List[DocNode]: 所有章节内容合并后的文本节点列表。

Source code in lazyllm/tools/rag/readers/epubReader.py
class EpubReader(LazyLLMReaderBase):
    """用于读取 `.epub` 格式电子书的文件读取器。

继承自 `LazyLLMReaderBase`,只需实现 `_load_data` 方法,即可通过 `Document` 组件自动加载 `.epub` 文件中的内容。

注意:当前版本不支持通过 fsspec 文件系统(如远程路径)加载 epub 文件,若提供 `fs` 参数,将回退到本地文件读取。

Returns:
    List[DocNode]: 所有章节内容合并后的文本节点列表。
"""
    def _load_data(self, file: Path, fs: Optional[AbstractFileSystem] = None) -> List[DocNode]:
        if not isinstance(file, Path): file = Path(file)

        if fs:
            LOG.warning("fs was specified but EpubReader doesn't support loading from "
                        "fsspec filesystems. Will load from local filesystem instead.")

        text_list = []

        spec = importlib.util.find_spec("ebooklib.epub")
        if spec is None:
            raise ImportError(
                "Please install ebooklib to use ebooklib module. "
                "You can install it with `pip install ebooklib`"
            )
        epub_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(epub_module)

        book = epub_module.read_epub(file, options={"ignore_ncs": True})

        for item in book.get_items():
            if item.get_type() == ebooklib.ITEM_DOCUMENT:
                text_list.append(html2text.html2text(item.get_content().decode("utf-8")))
        text = "\n".join(text_list)
        return [DocNode(text=text)]

lazyllm.tools.rag.readers.HWPReader

Bases: LazyLLMReaderBase

HWP文件解析器,支持从本地文件系统读取 HWP 文件。它会从文档中提取正文部分的文本内容,返回 DocNode 列表。

HWP 是一种专有的二进制格式,主要在韩国使用。由于格式封闭,因此只能解析部分内容(如文本段落),但对常规文本提取已经足够使用。

Parameters:

  • return_trace (bool, default: True ) –

    是否启用 trace 日志记录,默认为 True

Source code in lazyllm/tools/rag/readers/hwpReader.py
class HWPReader(LazyLLMReaderBase):
    """HWP文件解析器,支持从本地文件系统读取 HWP 文件。它会从文档中提取正文部分的文本内容,返回 DocNode 列表。

HWP 是一种专有的二进制格式,主要在韩国使用。由于格式封闭,因此只能解析部分内容(如文本段落),但对常规文本提取已经足够使用。

Args:
    return_trace (bool): 是否启用 trace 日志记录,默认为 ``True``。
"""
    def __init__(self, return_trace: bool = True) -> None:
        super().__init__(return_trace=return_trace)
        self._FILE_HEADER_SECTION = "FileHeader"
        self._HWP_SUMMARY_SECTION = "\x05HwpSummaryInformation"
        self._SECTION_NAME_LENGTH = len("Section")
        self._BODYTEXT_SECTION = "BodyText"
        self._HWP_TEXT_TAGS = [67]
        self._text = ""

    def _load_data(self, file: Path, fs: Optional[AbstractFileSystem] = None) -> List[DocNode]:
        if fs:
            LOG.warning("fs was specified but HWPReader doesn't support loading from "
                        "fsspec filesystems. Will load from local filesystem instead.")

        if not isinstance(file, Path): file = Path(file)

        load_file = olefile.OleFileIO(file)
        file_dir = load_file.listdir()
        if self._is_valid(file_dir) is False: raise Exception("Not Valid HwpFile")

        result_text = self._get_text(load_file, file_dir)
        return [DocNode(text=result_text)]

    def _is_valid(self, dirs: List[str]) -> bool:
        if [self._FILE_HEADER_SECTION] not in dirs: return False
        return [self._HWP_SUMMARY_SECTION] in dirs

    def _get_text(self, load_file: Any, file_dirs: List[str]) -> str:
        sections = self._get_body_sections(file_dirs)
        text = ""
        for section in sections:
            text += self._get_text_from_section(load_file, section)
            text += "\n"

        self._text = text
        return self._text

    def _get_body_sections(self, dirs: List[str]) -> List[str]:
        m = []
        for d in dirs:
            if d[0] == self._BODYTEXT_SECTION:
                m.append(int(d[1][self._SECTION_NAME_LENGTH:]))

        return ["BodyText/Section" + str(x) for x in sorted(m)]

    def _is_compressed(self, load_file: Any) -> bool:
        header = load_file.openstream("FileHeader")
        header_data = header.read()
        return (header_data[36] & 1) == 1

    def _get_text_from_section(self, load_file: Any, section: str) -> str:
        bodytext = load_file.openstream(section)
        data = bodytext.read()

        unpacked_data = (zlib.decompress(data, -15) if self._is_compressed(load_file) else data)
        size = len(unpacked_data)

        i = 0
        text = ""
        while i < size:
            header = struct.unpack_from("<I", unpacked_data, i)[0]
            rec_type = header & 0x3FF
            (header >> 10) & 0x3FF
            rec_len = (header >> 20) & 0xFFF

            if rec_type in self._HWP_TEXT_TAGS:
                rec_data = unpacked_data[i + 4: i + 4 + rec_len]
                text += rec_data.decode("utf-16")
                text += "\n"

            i += 4 + rec_len
        return text

lazyllm.tools.rag.readers.ImageReader

Bases: LazyLLMReaderBase

用于从图片文件中读取内容的模块。支持保留图片、解析图片中的文本(基于OCR或预训练视觉模型),并返回文本和图片路径的节点列表。

Parameters:

  • parser_config (Optional[Dict], default: None ) –

    解析器配置,包含模型和处理器,默认为 None。当设置 parse_text=True 且 parser_config=None 时,会自动根据 text_type 加载相应模型。

  • keep_image (bool, default: False ) –

    是否保留图片的 base64 编码,默认为 False。

  • parse_text (bool, default: False ) –

    是否解析图片中的文本,默认为 False。

  • text_type (str, default: 'text' ) –

    解析文本的类型,支持 text(默认)和 plain_text。当为 plain_text 时,使用 pytesseract 进行OCR;否则使用预训练视觉编码解码模型。

  • pytesseract_model_kwargs (Optional[Dict], default: None ) –

    传递给 pytesseract OCR 的可选参数,默认为空字典。

  • return_trace (bool, default: True ) –

    是否记录处理过程的 trace,默认为 True。

Source code in lazyllm/tools/rag/readers/imageReader.py
class ImageReader(LazyLLMReaderBase):
    """用于从图片文件中读取内容的模块。支持保留图片、解析图片中的文本(基于OCR或预训练视觉模型),并返回文本和图片路径的节点列表。

Args:
    parser_config (Optional[Dict]): 解析器配置,包含模型和处理器,默认为 None。当设置 parse_text=True 且 parser_config=None 时,会自动根据 text_type 加载相应模型。
    keep_image (bool): 是否保留图片的 base64 编码,默认为 False。
    parse_text (bool): 是否解析图片中的文本,默认为 False。
    text_type (str): 解析文本的类型,支持 ``text``(默认)和 ``plain_text``。当为 ``plain_text`` 时,使用 pytesseract 进行OCR;否则使用预训练视觉编码解码模型。
    pytesseract_model_kwargs (Optional[Dict]): 传递给 pytesseract OCR 的可选参数,默认为空字典。
    return_trace (bool): 是否记录处理过程的 trace,默认为 True。
"""
    def __init__(self, parser_config: Optional[Dict] = None, keep_image: bool = False, parse_text: bool = False,
                 text_type: str = "text", pytesseract_model_kwargs: Optional[Dict] = None,
                 return_trace: bool = True) -> None:
        super().__init__(return_trace=return_trace)
        self._text_type = text_type
        if parser_config is None and parse_text:
            if text_type == "plain_text":
                try:
                    import pytesseract
                except ImportError:
                    raise ImportError("Please install extra dependencies that are required for the ImageReader "
                                      "when text_type is 'plain_text': `pip install pytesseract`")

                processor = None
                model = pytesseract
            else:
                thirdparty.check_packages(["sentencepiece", "torch", "transformers"])

                processor = tf.DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
                model = tf.VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
            parser_config = {'processor': processor, 'model': model}

        self._parser_config = parser_config
        self._keep_image = keep_image
        self._parse_text = parse_text
        self._pytesseract_model_kwargs = pytesseract_model_kwargs or {}

    def _load_data(self, file: Path, fs: Optional[AbstractFileSystem] = None) -> List[ImageDocNode]:
        if not isinstance(file, Path): file = Path(file)

        if fs:
            with fs.open(path=file) as f:
                image = PIL.Image.open(f.read())
        else:
            image = PIL.Image.open(file)

        if image.mode != "RGB": image = image.convert("RGB")

        image_str: Optional[str] = None  # noqa
        if self._keep_image: image_str = img_2_b64(image)  # noqa

        text_str: str = ""
        if self._parse_text:
            assert self._parser_config is not None
            model = self._parser_config["model"]
            processor = self._parser_config["processor"]

            if processor:
                device = infer_torch_device()
                model.to(device)

                task_prompt = "<s_cord-v2>"
                decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False,
                                                        return_tensors='pt').input_ids
                pixel_values = processor(image, return_tensors='pt').pixel_values

                output = model.generate(pixel_values.to(device), decoder_input_ids=decoder_input_ids.to(device),
                                        max_length=model.decoder.config.max_position_embeddings, early_stopping=True,
                                        pad_token_id=processor.tokenizer.pad_token_id,
                                        eos_token_id=processor.tokenizer.eos_token_id, use_cache=True, num_beams=3,
                                        bad_words_ids=[[processor.tokenizer.unk_token_id]],
                                        return_dict_in_generate=True)

                sequence = processor.batch_decode(output.sequences)[0]
                sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
                text_str = re.sub(r"<.*?>", "", sequence, count=1).strip()
            else:
                import pytesseract

                model = cast(pytesseract, self._parser_config['model'])
                text_str = model.image_to_string(image, **self._pytesseract_model_kwargs)

        return [ImageDocNode(text=text_str, image_path=str(file))]

lazyllm.tools.rag.readers.IPYNBReader

Bases: LazyLLMReaderBase

用于读取和解析 Jupyter Notebook (.ipynb) 文件的模块。将 notebook 转换成脚本文本后,按代码单元划分为多个文档节点,或合并为单一文本节点。

Parameters:

  • parser_config (Optional[Dict], default: None ) –

    预留的解析器配置参数,当前未使用,默认为 None。

  • concatenate (bool, default: False ) –

    是否将所有代码单元合并成一个整体文本节点,默认为 False,即分割为多个节点。

  • return_trace (bool, default: True ) –

    是否记录处理过程的 trace,默认为 True。

Source code in lazyllm/tools/rag/readers/ipynbReader.py
class IPYNBReader(LazyLLMReaderBase):
    """用于读取和解析 Jupyter Notebook (.ipynb) 文件的模块。将 notebook 转换成脚本文本后,按代码单元划分为多个文档节点,或合并为单一文本节点。

Args:
    parser_config (Optional[Dict]): 预留的解析器配置参数,当前未使用,默认为 None。
    concatenate (bool): 是否将所有代码单元合并成一个整体文本节点,默认为 False,即分割为多个节点。
    return_trace (bool): 是否记录处理过程的 trace,默认为 True。
"""
    def __init__(self, parser_config: Optional[Dict] = None, concatenate: bool = False, return_trace: bool = True):
        super().__init__(return_trace=return_trace)
        self._parser_config = parser_config
        self._concatenate = concatenate

    def _load_data(self, file: Path, fs: Optional[AbstractFileSystem] = None) -> List[DocNode]:
        if not isinstance(file, Path): file = Path(file)

        if file.name.endswith(".ipynb"):
            try:
                import nbconvert
            except ImportError:
                raise ImportError("Please install nbconvert `pip install nbconvert`")

        if fs:
            with fs.open(file, encoding='utf-8') as f:
                doc_str = nbconvert.exporters.ScriptExporter().from_file(f)[0]
        else:
            doc_str = nbconvert.exporters.ScriptExporter().from_file(file)[0]

        splits = re.split(r"In\[\d+\]:", doc_str)
        splits.pop(0)

        if self._concatenate: docs = [DocNode(text="\n\n".join(splits))]
        else: docs = [DocNode(text=s) for s in splits]

        return docs

lazyllm.tools.rag.readers.MagicPDFReader

用于通过 MagicPDF 服务解析 PDF 文件内容的模块。支持上传文件或通过 URL 方式调用解析接口,解析结果经过回调函数处理成文档节点列表。

Parameters:

  • magic_url (str) –

    MagicPDF 服务的接口 URL。

  • callback (Optional[Callable[[List[dict], Path, dict], List[DocNode]]], default: None ) –

    解析结果回调函数,接收解析元素列表、文件路径及额外信息,返回文档节点列表。默认将所有文本合并为一个节点。

  • upload_mode (bool, default: False ) –

    是否采用文件上传模式调用接口,默认为 False,即通过 JSON 请求文件路径。

Source code in lazyllm/tools/rag/readers/magic_pdf_reader.py
class MagicPDFReader:
    """用于通过 MagicPDF 服务解析 PDF 文件内容的模块。支持上传文件或通过 URL 方式调用解析接口,解析结果经过回调函数处理成文档节点列表。

Args:
    magic_url (str): MagicPDF 服务的接口 URL。
    callback (Optional[Callable[[List[dict], Path, dict], List[DocNode]]]): 解析结果回调函数,接收解析元素列表、文件路径及额外信息,返回文档节点列表。默认将所有文本合并为一个节点。
    upload_mode (bool): 是否采用文件上传模式调用接口,默认为 False,即通过 JSON 请求文件路径。
"""

    def __init__(self, magic_url, callback: Optional[Callable[[List[dict], Path, dict], List[DocNode]]] = None,
                 upload_mode: bool = False):
        self._magic_url = magic_url
        self._upload_mode = upload_mode
        if callback is not None:
            self._callback = callback
        else:
            def default_callback(elements: List[dict], file: Path, extra_info: Optional[Dict] = None) -> List[DocNode]:
                text_chunks = [el["text"] for el in elements if "text" in el]
                return [DocNode(text="\n".join(text_chunks), metadata={"file_name": file.name})]
            self._callback = default_callback

    def __call__(self, file: Path, **kwargs) -> List[DocNode]:
        try:
            return self._load_data(file, **kwargs)
        except Exception as e:
            LOG.error(f"[MagicPDFReader] Error loading data from {file}: {e}")
            return []

    def _load_data(self, file: Path, extra_info: Optional[Dict] = None, **kwargs) -> List[DocNode]:
        if isinstance(file, str):
            file = Path(file)
        if self._upload_mode:
            elements = self._upload_parse_pdf_elements(file)
        else:
            elements = self._parse_pdf_elements(file)
        docs: List[DocNode] = self._callback(elements, file, extra_info)
        return docs

    def _parse_pdf_elements(self, pdf_path: Path) -> List[dict]:
        payload = {"files": [str(pdf_path)], "reserve_image": True}
        try:
            response = requests.post(self._magic_url, json=payload)
            response.raise_for_status()
            res = response.json()
            if not isinstance(res, list) or not res:
                LOG.info(f"[MagicPDFReader] No elements found in PDF: {pdf_path}")
                return []
        except requests.exceptions.RequestException as e:
            LOG.error(f"[MagicPDFReader] POST failed: {e}")
            return []
        return self._extract_content_blocks(res[0])

    def _upload_parse_pdf_elements(self, pdf_path: Path) -> List[dict]:
        try:
            with open(pdf_path, "rb") as f:
                files = {'file': (os.path.basename(pdf_path), f)}
                response = requests.post(self._magic_url, files=files)
                response.raise_for_status()
                res = response.json()
                if not isinstance(res, list) or not res:
                    LOG.info(f"[MagicPDFReader] No elements found in PDF: {pdf_path}")
                    return []
        except requests.exceptions.RequestException as e:
            LOG.error(f"[MagicPDFReader] POST failed: {e}")
            return []
        return self._extract_content_blocks(res[0])

    def _extract_content_blocks(self, content_list) -> List[dict]:  # noqa: C901
        blocks = []
        cur_title = ""
        cur_level = -1
        for content in content_list:
            block = {}
            block["bbox"] = content["bbox"]
            block["lines"] = content["lines"] if 'lines' in content else []
            for line in block['lines']:
                line['content'] = self._clean_content(line['content'])
            if content["type"] == "text":
                content["text"] = self._clean_content(content["text"]).strip()
                if not content["text"]:
                    continue
                if "text_level" in content:
                    if cur_title and content["text_level"] > cur_level:
                        content["title"] = cur_title
                    cur_title = content["text"]
                    cur_level = content["text_level"]
                else:
                    if cur_title:
                        content["title"] = cur_title
                block = copy.deepcopy(content)
                block["page"] = content["page_idx"]
                del block["page_idx"]
                blocks.append(block)
            elif content["type"] == "image":
                if not content["img_path"]:
                    continue
                block["type"] = content["type"]
                block["page"] = content["page_idx"]
                block["image_path"] = os.path.basename(content["img_path"])
                block['img_caption'] = self._clean_content(content['img_caption'])
                block['img_footnote'] = self._clean_content(content['img_footnote'])
                if cur_title:
                    block["title"] = cur_title
                img_title = block["img_caption"][0] if len(block["img_caption"]) > 0 else ""
                block["text"] = f"![{img_title}]({block['image_path']})"
                blocks.append(block)
            elif content["type"] == "table":
                block["type"] = content["type"]
                block["page"] = content["page_idx"]
                if self.extract_table:
                    block["text"] = self._html_table_to_markdown(self._clean_content(content["table_body"])
                                                                 ) if "table_body" in content else ""
                else:
                    block['image_path'] = os.path.basename(content['img_path'])
                if cur_title:
                    block["title"] = cur_title
                block['table_caption'] = self._clean_content(content['table_caption'])
                block['table_footnote'] = self._clean_content(content['table_footnote'])
                blocks.append(block)
        return blocks

    def _clean_content(self, content) -> str:
        if isinstance(content, str):
            content = content.encode("utf-8", "replace").decode("utf-8")
            return unicodedata.normalize("NFKC", content)
        if isinstance(content, list):
            return [self._clean_content(t) for t in content]
        return content

    def _html_table_to_markdown(self, html_table) -> str:  # noqa: C901
        try:
            soup = BeautifulSoup(html_table.strip(), 'html.parser')
            table = soup.find('table')
            if not table:
                raise ValueError("No <table> found in the HTML.")

            rows = []
            max_cols = 0

            for row in table.find_all('tr'):
                cells = []
                for cell in row.find_all(['td', 'th']):
                    rowspan = int(cell.get('rowspan', 1))
                    colspan = int(cell.get('colspan', 1))
                    text = cell.get_text(strip=True)

                    for _ in range(colspan):
                        cells.append({'text': text, 'rowspan': rowspan})
                rows.append(cells)
                max_cols = max(max_cols, len(cells))

            expanded_rows = []
            rowspan_tracker = [0] * max_cols
            for row in rows:
                expanded_row = []
                col_idx = 0
                for cell in row:
                    while col_idx < max_cols and rowspan_tracker[col_idx] > 0:
                        expanded_row.append(None)
                        rowspan_tracker[col_idx] -= 1
                        col_idx += 1

                    expanded_row.append(cell['text'])
                    if cell['rowspan'] > 1:
                        rowspan_tracker[col_idx] = cell['rowspan'] - 1
                    col_idx += 1

                while col_idx < max_cols:
                    if rowspan_tracker[col_idx] > 0:
                        expanded_row.append(None)
                        rowspan_tracker[col_idx] -= 1
                    else:
                        expanded_row.append("")
                    col_idx += 1

                expanded_rows.append(expanded_row)

            markdown = ''
            if not expanded_rows:
                return ""

            headers = expanded_rows[0]
            body_rows = expanded_rows[1:]
            if headers:
                markdown += '| ' + ' | '.join(h if h else '' for h in headers) + ' |\n'
                markdown += '| ' + ' | '.join(['-' * (len(h) if h else 3) for h in headers]) + ' |\n'
            for row in body_rows:
                markdown += '| ' + ' | '.join(cell if cell else '' for cell in row) + ' |\n'

            return markdown

        except Exception as e:
            LOG.error(f"Error parsing table: {e}")
            return ''

lazyllm.tools.rag.readers.MarkdownReader

Bases: LazyLLMReaderBase

用于读取和解析 Markdown 文件的模块。支持去除超链接和图片,按标题和内容将 Markdown 划分成若干文本段落节点。

Parameters:

  • remove_hyperlinks (bool, default: True ) –

    是否移除超链接,默认 True。

  • remove_images (bool, default: True ) –

    是否移除图片标记,默认 True。

  • return_trace (bool, default: True ) –

    是否记录处理过程的 trace,默认为 True。

Source code in lazyllm/tools/rag/readers/markdownReader.py
class MarkdownReader(LazyLLMReaderBase):
    """用于读取和解析 Markdown 文件的模块。支持去除超链接和图片,按标题和内容将 Markdown 划分成若干文本段落节点。

Args:
    remove_hyperlinks (bool): 是否移除超链接,默认 True。
    remove_images (bool): 是否移除图片标记,默认 True。
    return_trace (bool): 是否记录处理过程的 trace,默认为 True。
"""
    def __init__(self, remove_hyperlinks: bool = True, remove_images: bool = True, return_trace: bool = True) -> None:
        super().__init__(return_trace=return_trace)
        self._remove_hyperlinks = remove_hyperlinks
        self._remove_images = remove_images

    def _markdown_to_tups(self, markdown_text: str) -> List[Tuple[Optional[str], str]]:
        markdown_tups: List[Tuple[Optional[str], str]] = []
        lines = markdown_text.split("\n")

        current_header = None
        current_lines = []
        in_code_block = False

        for line in lines:
            if line.startswith("```"): in_code_block = not in_code_block

            header_match = re.match(r"^#+\s", line)
            if not in_code_block and header_match:
                if current_header is not None or len(current_lines) > 0:
                    markdown_tups.append((current_header, "\n".join(current_lines)))
                current_header = line
                current_lines.clear()
            else:
                current_lines.append(line)

        markdown_tups.append((current_header, "\n".join(current_lines)))
        return [(key if key is None else re.sub(r"#", "", key).strip(), re.sub(r"<.*?>", "", value),)
                for key, value in markdown_tups]

    def remove_images(self, content: str) -> str:
        """移除内容中形如 ![[...]] 的自定义图片标签。

Args:
    content (str): 输入的 markdown 内容。

Returns:
    str: 移除图片标签后的内容。
"""
        pattern = r"!{1}\[\[(.*)\]\]"
        return re.sub(pattern, "", content)

    def remove_hyperlinks(self, content: str) -> str:
        """移除 Markdown 超链接,将 [文本](链接) 转换为纯文本。

Args:
    content (str): 输入的 markdown 内容。

Returns:
    str: 移除超链接后的内容,仅保留链接文本。
"""
        pattern = r"\[(.*)\]\((.*)\)"
        return re.sub(pattern, r"\1", content)

    def _parse_tups(self, filepath: Path, errors: str = "ignore",
                    fs: Optional[AbstractFileSystem] = None) -> List[Tuple[Optional[str], str]]:
        fs = fs or LocalFileSystem()

        with fs.open(filepath, encoding="utf-8") as f:
            content = f.read().decode(encoding="utf-8")

        if self._remove_hyperlinks: content = self.remove_hyperlinks(content)
        if self._remove_images: content = self.remove_images(content)
        return self._markdown_to_tups(content)

    def _load_data(self, file: Path, fs: Optional[AbstractFileSystem] = None) -> List[DocNode]:
        if not isinstance(file, Path): file = Path(file)

        tups = self._parse_tups(file, fs=fs)
        results = [DocNode(
            content=[value if header is None else f"\n\n{header}\n{value}" for header, value in tups])]
        return results

移除 Markdown 超链接,将 文本 转换为纯文本。

Parameters:

  • content (str) –

    输入的 markdown 内容。

Returns:

  • str ( str ) –

    移除超链接后的内容,仅保留链接文本。

Source code in lazyllm/tools/rag/readers/markdownReader.py
    def remove_hyperlinks(self, content: str) -> str:
        """移除 Markdown 超链接,将 [文本](链接) 转换为纯文本。

Args:
    content (str): 输入的 markdown 内容。

Returns:
    str: 移除超链接后的内容,仅保留链接文本。
"""
        pattern = r"\[(.*)\]\((.*)\)"
        return re.sub(pattern, r"\1", content)

remove_images(content)

移除内容中形如 ![[...]] 的自定义图片标签。

Parameters:

  • content (str) –

    输入的 markdown 内容。

Returns:

  • str ( str ) –

    移除图片标签后的内容。

Source code in lazyllm/tools/rag/readers/markdownReader.py
    def remove_images(self, content: str) -> str:
        """移除内容中形如 ![[...]] 的自定义图片标签。

Args:
    content (str): 输入的 markdown 内容。

Returns:
    str: 移除图片标签后的内容。
"""
        pattern = r"!{1}\[\[(.*)\]\]"
        return re.sub(pattern, "", content)

lazyllm.tools.rag.readers.MboxReader

Bases: LazyLLMReaderBase

用于解析 Mbox 邮件存档文件的模块。读取邮件内容并格式化为文本,支持限制最大邮件数和自定义消息格式。

Parameters:

  • max_count (int, default: 0 ) –

    最大读取的邮件数量,默认 0 表示读取全部邮件。

  • message_format (str, default: DEFAULT_MESSAGE_FORMAT ) –

    邮件文本格式模板,支持使用 {_date}{_from}{_to}{_subject}{_content} 占位符。

  • return_trace (bool, default: True ) –

    是否记录处理过程的 trace,默认为 True。

Source code in lazyllm/tools/rag/readers/mboxreader.py
class MboxReader(LazyLLMReaderBase):
    """用于解析 Mbox 邮件存档文件的模块。读取邮件内容并格式化为文本,支持限制最大邮件数和自定义消息格式。

Args:
    max_count (int): 最大读取的邮件数量,默认 0 表示读取全部邮件。
    message_format (str): 邮件文本格式模板,支持使用 ``{_date}``、``{_from}``、``{_to}``、``{_subject}`` 和 ``{_content}`` 占位符。
    return_trace (bool): 是否记录处理过程的 trace,默认为 True。
"""
    DEFAULT_MESSAGE_FORMAT: str = (
        "Date: {_date}\n"
        "From: {_from}\n"
        "To: {_to}\n"
        "Subject: {_subject}\n"
        "Content: {_content}"
    )

    def __init__(self, max_count: int = 0, message_format: str = DEFAULT_MESSAGE_FORMAT,
                 return_trace: bool = True) -> None:
        try:
            from bs4 import BeautifulSoup  # noqa
        except ImportError:
            raise ImportError("`BeautifulSoup` package not found: `pip install beautifulsoup4`")

        super().__init__(return_trace=return_trace)
        self._max_count = max_count
        self._message_format = message_format

    def _load_data(self, file: Path, fs: Optional[AbstractFileSystem] = None) -> List[DocNode]:
        import mailbox
        from email.parser import BytesParser
        from email.policy import default
        from bs4 import BeautifulSoup

        if fs:
            LOG.warning("fs was specified but MboxReader doesn't support loading from "
                        "fsspec filesystems. Will load from local filesystem instead.")

        i = 0
        results: List[str] = []
        bytes_parser = BytesParser(policy=default).parse
        mbox = mailbox.mbox(file, factory=bytes_parser)

        for _, _msg in enumerate(mbox):
            try:
                msg: mailbox.mboxMessage = _msg
                if msg.is_multipart():
                    for part in msg.walk():
                        ctype = part.get_content_type()
                        cdispo = str(part.get("Content-Disposition"))
                        if ctype == "text/plain" and "attachment" not in cdispo:
                            content = part.get_payload(decode=True)
                            break
                else:
                    content = msg.get_payload(decode=True)

                soup = BeautifulSoup(content)
                stripped_content = " ".join(soup.get_text().split())
                msg_string = self._message_format.format(_date=msg["date"], _from=msg["from"], _to=msg["to"],
                                                         _subject=msg["subject"], _content=stripped_content)
                results.append(msg_string)
            except Exception as e:
                LOG.warning(f"Failed to parse message:\n{_msg}\n with exception {e}")

            i += 1
            if self._max_count > 0 and i >= self._max_count: break
        return [DocNode(text=result) for result in results]

lazyllm.tools.rag.default_index.DefaultIndex

Bases: IndexBase

\ 默认的索引实现,负责通过 embedding 和文本相似度在底层存储中查询、更新和删除文档节点。支持多种相似度度量方式,并在必要时对查询和节点进行 embedding 计算与更新。

Parameters:

  • embed (Dict[str, Callable]) –

    用于生成查询和节点 embedding 的字典,key 是 embedding 名称,value 是接收字符串返回向量的函数。

  • store (StoreBase) –

    底层存储,用于持久化和检索 DocNode 节点。

  • **kwargs

    预留扩展参数。

Source code in lazyllm/tools/rag/default_index.py
class DefaultIndex(IndexBase):
    """\ 
默认的索引实现,负责通过 embedding 和文本相似度在底层存储中查询、更新和删除文档节点。支持多种相似度度量方式,并在必要时对查询和节点进行 embedding 计算与更新。

Args:
    embed (Dict[str, Callable]): 用于生成查询和节点 embedding 的字典,key 是 embedding 名称,value 是接收字符串返回向量的函数。
    store (StoreBase): 底层存储,用于持久化和检索 DocNode 节点。
    **kwargs: 预留扩展参数。
"""
    def __init__(self, embed: Dict[str, Callable], store, **kwargs):
        self.embed = embed
        self.store = store

    @override
    def update(self, nodes: List[DocNode]) -> None:
        """\ 
根据提供的节点列表更新索引中的内容。具体行为由子类或外部实现填充(此处为空实现,需在实际使用中覆盖/扩展)。

Args:
    nodes (List[DocNode]): 需要更新(新增或替换)的文档节点列表。
"""
        pass

    @override
    def remove(self, uids: List[str], group_name: Optional[str] = None) -> None:
        """\ 
从索引中删除指定 UID 的节点,可选指定分组名称以限定作用域。当前为空实现,使用时需要补全逻辑。

Args:
    uids (List[str]): 要删除的节点唯一标识列表。
    group_name (Optional[str]): 可选的分组名称,用于限定删除范围。
"""
        pass

    @override
    def query(
        self,
        query: str,
        group_name: str,
        similarity_name: str,
        similarity_cut_off: Union[float, Dict[str, float]],
        topk: int,
        embed_keys: Optional[List[str]] = None,
        filters: Optional[Dict[str, List]] = None,
        **kwargs,
    ) -> List[DocNode]:
        """\ 
执行一次查询,支持 embedding 和文本两种模式,依据相似度函数过滤并返回符合条件的 DocNode 结果。

Args:
    query (str): 原始查询文本。
    group_name (str): 要检索的节点组名称。
    similarity_name (str): 使用的相似度度量名称,必须在 registered_similarities 中注册。
    similarity_cut_off (Union[float, Dict[str, float]]): 相似度阈值或每个 embedding 对应的阈值字典,用于过滤结果。
    topk (int): 每个相似度渠道最多保留的候选数量。
    embed_keys (Optional[List[str]]): 指定用于 embedding 的 key 列表,若为空则使用所有可用 embedding。
    filters (Optional[Dict[str, List]]): 额外的节点过滤器,应用在计算相似度前。
    **kwargs: 传递给相似度函数的额外参数。

**Returns**\n
    - list: List[DocNode]: 经过相似度计算与阈值过滤后去重的文档节点列表。
"""
        if similarity_name not in registered_similarities:
            raise ValueError(
                f"{similarity_name} not registered, please check your input. "
                f"Available options now: {registered_similarities.keys()}"
            )
        similarity_func, mode, descend = registered_similarities[similarity_name]

        nodes = self.store.get_nodes(group=group_name)
        if filters:
            nodes = generic_process_filters(nodes, filters)

        if mode == "embedding":
            assert self.embed, "Chosen similarity needs embed model."
            assert len(query) > 0, "Query should not be empty."
            if not embed_keys:
                embed_keys = list(self.embed.keys())
            query_embedding = {k: self.embed[k](query) for k in embed_keys}
            self._check_supported(similarity_name, query_embedding)
            modified_nodes = parallel_do_embedding(self.embed, embed_keys, nodes)
            self.store.update_nodes(modified_nodes)
            similarities = similarity_func(query_embedding, nodes, topk=topk, **kwargs)
        elif mode == "text":
            similarities = similarity_func(query, nodes, topk=topk, **kwargs)
        else:
            raise NotImplementedError(f"Mode {mode} is not supported.")

        if not isinstance(similarities, dict):
            results = self._filter_nodes_by_score(similarities, topk, similarity_cut_off, descend)
        else:
            results = []
            for key in (embed_keys or similarities.keys()):
                sims = similarities[key]
                sim_cut_off = similarity_cut_off if isinstance(similarity_cut_off, float) else similarity_cut_off[key]
                results.extend(self._filter_nodes_by_score(sims, topk, sim_cut_off, descend))
        results = list(set(results))
        LOG.debug(f"Retrieving query `{query}` and get results: {results}")
        return results

    def _filter_nodes_by_score(self, similarities: List[Tuple[DocNode, float]], topk: int,
                               similarity_cut_off: float, descend) -> List[DocNode]:
        similarities.sort(key=lambda x: x[1], reverse=descend)
        if topk is not None:
            similarities = similarities[:topk]

        return [node.with_sim_score(score) for node, score in similarities if score > similarity_cut_off]

    def _check_supported(self, similarity_name: str, query_embedding: Dict[str, Any]) -> None:
        if similarity_name.lower() == 'cosine':
            for k, e in query_embedding.items():
                if is_sparse(e):
                    raise NotImplementedError(f'embed `{k}`, which is sparse, is not supported.')

query(query, group_name, similarity_name, similarity_cut_off, topk, embed_keys=None, filters=None, **kwargs)

\ 执行一次查询,支持 embedding 和文本两种模式,依据相似度函数过滤并返回符合条件的 DocNode 结果。

Parameters:

  • query (str) –

    原始查询文本。

  • group_name (str) –

    要检索的节点组名称。

  • similarity_name (str) –

    使用的相似度度量名称,必须在 registered_similarities 中注册。

  • similarity_cut_off (Union[float, Dict[str, float]]) –

    相似度阈值或每个 embedding 对应的阈值字典,用于过滤结果。

  • topk (int) –

    每个相似度渠道最多保留的候选数量。

  • embed_keys (Optional[List[str]], default: None ) –

    指定用于 embedding 的 key 列表,若为空则使用所有可用 embedding。

  • filters (Optional[Dict[str, List]], default: None ) –

    额外的节点过滤器,应用在计算相似度前。

  • **kwargs

    传递给相似度函数的额外参数。

Returns

- list: List[DocNode]: 经过相似度计算与阈值过滤后去重的文档节点列表。
Source code in lazyllm/tools/rag/default_index.py
    @override
    def query(
        self,
        query: str,
        group_name: str,
        similarity_name: str,
        similarity_cut_off: Union[float, Dict[str, float]],
        topk: int,
        embed_keys: Optional[List[str]] = None,
        filters: Optional[Dict[str, List]] = None,
        **kwargs,
    ) -> List[DocNode]:
        """\ 
执行一次查询,支持 embedding 和文本两种模式,依据相似度函数过滤并返回符合条件的 DocNode 结果。

Args:
    query (str): 原始查询文本。
    group_name (str): 要检索的节点组名称。
    similarity_name (str): 使用的相似度度量名称,必须在 registered_similarities 中注册。
    similarity_cut_off (Union[float, Dict[str, float]]): 相似度阈值或每个 embedding 对应的阈值字典,用于过滤结果。
    topk (int): 每个相似度渠道最多保留的候选数量。
    embed_keys (Optional[List[str]]): 指定用于 embedding 的 key 列表,若为空则使用所有可用 embedding。
    filters (Optional[Dict[str, List]]): 额外的节点过滤器,应用在计算相似度前。
    **kwargs: 传递给相似度函数的额外参数。

**Returns**\n
    - list: List[DocNode]: 经过相似度计算与阈值过滤后去重的文档节点列表。
"""
        if similarity_name not in registered_similarities:
            raise ValueError(
                f"{similarity_name} not registered, please check your input. "
                f"Available options now: {registered_similarities.keys()}"
            )
        similarity_func, mode, descend = registered_similarities[similarity_name]

        nodes = self.store.get_nodes(group=group_name)
        if filters:
            nodes = generic_process_filters(nodes, filters)

        if mode == "embedding":
            assert self.embed, "Chosen similarity needs embed model."
            assert len(query) > 0, "Query should not be empty."
            if not embed_keys:
                embed_keys = list(self.embed.keys())
            query_embedding = {k: self.embed[k](query) for k in embed_keys}
            self._check_supported(similarity_name, query_embedding)
            modified_nodes = parallel_do_embedding(self.embed, embed_keys, nodes)
            self.store.update_nodes(modified_nodes)
            similarities = similarity_func(query_embedding, nodes, topk=topk, **kwargs)
        elif mode == "text":
            similarities = similarity_func(query, nodes, topk=topk, **kwargs)
        else:
            raise NotImplementedError(f"Mode {mode} is not supported.")

        if not isinstance(similarities, dict):
            results = self._filter_nodes_by_score(similarities, topk, similarity_cut_off, descend)
        else:
            results = []
            for key in (embed_keys or similarities.keys()):
                sims = similarities[key]
                sim_cut_off = similarity_cut_off if isinstance(similarity_cut_off, float) else similarity_cut_off[key]
                results.extend(self._filter_nodes_by_score(sims, topk, sim_cut_off, descend))
        results = list(set(results))
        LOG.debug(f"Retrieving query `{query}` and get results: {results}")
        return results

remove(uids, group_name=None)

\ 从索引中删除指定 UID 的节点,可选指定分组名称以限定作用域。当前为空实现,使用时需要补全逻辑。

Parameters:

  • uids (List[str]) –

    要删除的节点唯一标识列表。

  • group_name (Optional[str], default: None ) –

    可选的分组名称,用于限定删除范围。

Source code in lazyllm/tools/rag/default_index.py
    @override
    def remove(self, uids: List[str], group_name: Optional[str] = None) -> None:
        """\ 
从索引中删除指定 UID 的节点,可选指定分组名称以限定作用域。当前为空实现,使用时需要补全逻辑。

Args:
    uids (List[str]): 要删除的节点唯一标识列表。
    group_name (Optional[str]): 可选的分组名称,用于限定删除范围。
"""
        pass

update(nodes)

\ 根据提供的节点列表更新索引中的内容。具体行为由子类或外部实现填充(此处为空实现,需在实际使用中覆盖/扩展)。

Parameters:

  • nodes (List[DocNode]) –

    需要更新(新增或替换)的文档节点列表。

Source code in lazyllm/tools/rag/default_index.py
    @override
    def update(self, nodes: List[DocNode]) -> None:
        """\ 
根据提供的节点列表更新索引中的内容。具体行为由子类或外部实现填充(此处为空实现,需在实际使用中覆盖/扩展)。

Args:
    nodes (List[DocNode]): 需要更新(新增或替换)的文档节点列表。
"""
        pass

lazyllm.tools.Reranker

Bases: ModuleBase, _PostProcess

用于创建节点(文档)后处理和重排序的模块。

Parameters:

  • name (str, default: 'ModuleReranker' ) –

    用于后处理和重排序过程的排序器类型。默认为 'ModuleReranker'。

  • target(str)

    已废弃参数,仅用于提示用户。

  • output_format (Optional[str], default: None ) –

    代表输出格式,默认为None,可选值有 'content' 和 'dict',其中 content 对应输出格式为字符串,dict 对应字典。

  • join (Union[bool, str], default: False ) –

    是否联合输出的 k 个节点,当输出格式为 content 时,如果设置该值为 True,则输出一个长字符串,如果设置为 False 则输出一个字符串列表,其中每个字符串对应每个节点的文本内容。当输出格式是 dict 时,不能联合输出,此时join默认为False,,将输出一个字典,包括'content、'embedding'、'metadata'三个key。

  • kwargs

    传递给重新排序器实例化的其他关键字参数。

详细解释排序器类型

  • Reranker: 实例化一个具有待排序的文档节点node列表和 query的 SentenceTransformerRerank 重排序器。
  • KeywordFilter: 实例化一个具有指定必需和排除关键字的 KeywordNodePostprocessor。它根据这些关键字的存在或缺失来过滤节点。

Examples:

>>> import lazyllm
>>> from lazyllm.tools import Document, Reranker, Retriever, DocNode
>>> m = lazyllm.OnlineEmbeddingModule()
>>> documents = Document(dataset_path='/path/to/user/data', embed=m, manager=False)
>>> retriever = Retriever(documents, group_name='CoarseChunk', similarity='bm25', similarity_cut_off=0.01, topk=6)
>>> reranker = Reranker(DocNode(text=user_data),query="user query")
>>> ppl = lazyllm.ActionModule(retriever, reranker)
>>> ppl.start()
>>> print(ppl("user query"))
Source code in lazyllm/tools/rag/rerank.py
class Reranker(ModuleBase, _PostProcess):
    """用于创建节点(文档)后处理和重排序的模块。

Args:
    name: 用于后处理和重排序过程的排序器类型。默认为 'ModuleReranker'。
    target(str):已废弃参数,仅用于提示用户。
    output_format: 代表输出格式,默认为None,可选值有 'content' 和 'dict',其中 content 对应输出格式为字符串,dict 对应字典。
    join: 是否联合输出的 k 个节点,当输出格式为 content 时,如果设置该值为 True,则输出一个长字符串,如果设置为 False 则输出一个字符串列表,其中每个字符串对应每个节点的文本内容。当输出格式是 dict 时,不能联合输出,此时join默认为False,,将输出一个字典,包括'content、'embedding'、'metadata'三个key。
    kwargs: 传递给重新排序器实例化的其他关键字参数。

详细解释排序器类型

  - Reranker: 实例化一个具有待排序的文档节点node列表和 query的 SentenceTransformerRerank 重排序器。
  - KeywordFilter: 实例化一个具有指定必需和排除关键字的 KeywordNodePostprocessor。它根据这些关键字的存在或缺失来过滤节点。


Examples:

    >>> import lazyllm
    >>> from lazyllm.tools import Document, Reranker, Retriever, DocNode
    >>> m = lazyllm.OnlineEmbeddingModule()
    >>> documents = Document(dataset_path='/path/to/user/data', embed=m, manager=False)
    >>> retriever = Retriever(documents, group_name='CoarseChunk', similarity='bm25', similarity_cut_off=0.01, topk=6)
    >>> reranker = Reranker(DocNode(text=user_data),query="user query")
    >>> ppl = lazyllm.ActionModule(retriever, reranker)
    >>> ppl.start()
    >>> print(ppl("user query"))
    """
    registered_reranker = dict()

    def __new__(cls, name: str = "ModuleReranker", *args, **kwargs):
        assert name in cls.registered_reranker, f"Reranker: {name} is not registered, please register first."
        item = cls.registered_reranker[name]
        if isinstance(item, type) and issubclass(item, Reranker):
            return super(Reranker, cls).__new__(item)
        else:
            return super(Reranker, cls).__new__(cls)

    def __init__(self, name: str = "ModuleReranker", target: Optional[str] = None,
                 output_format: Optional[str] = None, join: Union[bool, str] = False, **kwargs) -> None:
        super().__init__()
        self._name = name
        self._kwargs = kwargs
        lazyllm.deprecated(bool(target), '`target` parameter of reranker')
        _PostProcess.__init__(self, output_format, join)

    def forward(self, nodes: List[DocNode], query: str = "") -> List[DocNode]:
        results = self.registered_reranker[self._name](nodes, query=query, **self._kwargs)
        LOG.debug(f"Rerank use `{self._name}` and get nodes: {results}")
        return self._post_process(results)

    @classmethod
    def register_reranker(
        cls: "Reranker", func: Optional[Callable] = None, batch: bool = False
    ):
        """是一个类装饰器工厂方法,它的核心作用是为 Reranker 类提供灵活的排序算法注册机制
Args:
    func (Optional[Callable]):  要注册的排序函数或排序器类。当使用装饰器语法(@)时可省略。
    batch (bool):是否批量处理节点。默认为False,表示逐节点处理。


Examples:

    @Reranker.register_reranker
    def my_reranker(node: DocNode, **kwargs):
        return node.score * 0.8  # 自定义分数计算
    """
        def decorator(f):
            if isinstance(f, type):
                cls.registered_reranker[f.__name__] = f
                return f
            else:
                def wrapper(nodes, **kwargs):
                    if batch:
                        return f(nodes, **kwargs)
                    else:
                        results = [f(node, **kwargs) for node in nodes]
                        return [result for result in results if result]

                cls.registered_reranker[f.__name__] = wrapper
                return wrapper

        return decorator(func) if func else decorator

register_reranker(func=None, batch=False) classmethod

是一个类装饰器工厂方法,它的核心作用是为 Reranker 类提供灵活的排序算法注册机制 Args: func (Optional[Callable]): 要注册的排序函数或排序器类。当使用装饰器语法(@)时可省略。 batch (bool):是否批量处理节点。默认为False,表示逐节点处理。

Examples:

@Reranker.register_reranker
def my_reranker(node: DocNode, **kwargs):
    return node.score * 0.8  # 自定义分数计算
Source code in lazyllm/tools/rag/rerank.py
    @classmethod
    def register_reranker(
        cls: "Reranker", func: Optional[Callable] = None, batch: bool = False
    ):
        """是一个类装饰器工厂方法,它的核心作用是为 Reranker 类提供灵活的排序算法注册机制
Args:
    func (Optional[Callable]):  要注册的排序函数或排序器类。当使用装饰器语法(@)时可省略。
    batch (bool):是否批量处理节点。默认为False,表示逐节点处理。


Examples:

    @Reranker.register_reranker
    def my_reranker(node: DocNode, **kwargs):
        return node.score * 0.8  # 自定义分数计算
    """
        def decorator(f):
            if isinstance(f, type):
                cls.registered_reranker[f.__name__] = f
                return f
            else:
                def wrapper(nodes, **kwargs):
                    if batch:
                        return f(nodes, **kwargs)
                    else:
                        results = [f(node, **kwargs) for node in nodes]
                        return [result for result in results if result]

                cls.registered_reranker[f.__name__] = wrapper
                return wrapper

        return decorator(func) if func else decorator

lazyllm.tools.Retriever

Bases: ModuleBase, _PostProcess

创建一个用于文档查询和检索的检索模块。此构造函数初始化一个检索模块,该模块根据指定的相似度度量配置文档检索过程。

Parameters:

  • doc (object) –

    文档模块实例。该文档模块可以是单个实例,也可以是一个实例的列表。如果是单个实例,表示对单个Document进行检索,如果是实例的列表,则表示对多个Document进行检索。

  • group_name (str) –

    在哪个 node group 上进行检索。

  • similarity (Optional[str], default: None ) –

    用于设置文档检索的相似度函数。默认为 'dummy'。候选集包括 ["bm25", "bm25_chinese", "cosine"]。

  • similarity_cut_off (Union[float, Dict[str, float]], default: float('-inf') ) –

    当相似度低于指定值时丢弃该文档。在多 embedding 场景下,如果需要对不同的 embedding 指定不同的值,则需要使用字典的方式指定,key 表示指定的是哪个 embedding,value 表示相应的阈值。如果所有的 embedding 使用同一个阈值,则只指定一个数值即可。

  • index (str, default: 'default' ) –

    用于文档检索的索引类型。目前仅支持 'default'。

  • topk (int, default: 6 ) –

    表示取相似度最高的多少篇文档。

  • embed_keys (Optional[List[str]], default: None ) –

    表示通过哪些 embedding 做检索,不指定表示用全部 embedding 进行检索。

  • output_format (Optional[str], default: None ) –

    代表输出格式,默认为None,可选值有 'content' 和 'dict',其中 content 对应输出格式为字符串,dict 对应字典。

  • join (Union[bool, str], default: False ) –

    是否联合输出的 k 个节点,当输出格式为 content 时,如果设置该值为 True,则输出一个长字符串,如果设置为 False 则输出一个字符串列表,其中每个字符串对应每个节点的文本内容。当输出格式是 dict 时,不能联合输出,此时join默认为False,,将输出一个字典,包括'content、'embedding'、'metadata'三个key。

其中 group_name 有三个内置的切分策略,都是使用 SentenceSplitter 做切分,区别在于块大小不同:

  • CoarseChunk: 块大小为 1024,重合长度为 100
  • MediumChunk: 块大小为 256,重合长度为 25
  • FineChunk: 块大小为 128,重合长度为 12

此外,LazyLLM提供了内置的Image节点组存储了所有图像节点,支持图像嵌入和检索。

Examples:

>>> import lazyllm
>>> from lazyllm.tools import Retriever, Document, SentenceSplitter
>>> m = lazyllm.OnlineEmbeddingModule()
>>> documents = Document(dataset_path='/path/to/user/data', embed=m, manager=False)
>>> rm = Retriever(documents, group_name='CoarseChunk', similarity='bm25', similarity_cut_off=0.01, topk=6)
>>> rm.start()
>>> print(rm("user query"))
>>> m1 = lazyllm.TrainableModule('bge-large-zh-v1.5').start()
>>> document1 = Document(dataset_path='/path/to/user/data', embed={'online':m , 'local': m1}, manager=False)
>>> document1.create_node_group(name='sentences', transform=SentenceSplitter, chunk_size=1024, chunk_overlap=100)
>>> retriever = Retriever(document1, group_name='sentences', similarity='cosine', similarity_cut_off=0.4, embed_keys=['local'], topk=3)
>>> print(retriever("user query"))
>>> document2 = Document(dataset_path='/path/to/user/data', embed={'online':m , 'local': m1}, manager=False)
>>> document2.create_node_group(name='sentences', transform=SentenceSplitter, chunk_size=512, chunk_overlap=50)
>>> retriever2 = Retriever([document1, document2], group_name='sentences', similarity='cosine', similarity_cut_off=0.4, embed_keys=['local'], topk=3)
>>> print(retriever2("user query"))
>>>
>>> filters = {
>>>     "author": ["A", "B", "C"],
>>>     "public_year": [2002, 2003, 2004],
>>> }
>>> document3 = Document(dataset_path='/path/to/user/data', embed={'online':m , 'local': m1}, manager=False)
>>> document3.create_node_group(name='sentences', transform=SentenceSplitter, chunk_size=512, chunk_overlap=50)
>>> retriever3 = Retriever([document1, document3], group_name='sentences', similarity='cosine', similarity_cut_off=0.4, embed_keys=['local'], topk=3)
>>> print(retriever3(query="user query", filters=filters))
>>> document4 = Document(dataset_path='/path/to/user/data', embed=lazyllm.TrainableModule('siglip'))
>>> retriever4 = Retriever(document4, group_name='Image', similarity='cosine')
>>> nodes = retriever4("user query")
>>> print([node.get_content() for node in nodes])
>>> document5 = Document(dataset_path='/path/to/user/data', embed=m, manager=False)
>>> rm = Retriever(document5, group_name='CoarseChunk', similarity='bm25_chinese', similarity_cut_off=0.01, topk=3, output_format='content')
>>> rm.start()
>>> print(rm("user query"))
>>> document6 = Document(dataset_path='/path/to/user/data', embed=m, manager=False)
>>> rm = Retriever(document6, group_name='CoarseChunk', similarity='bm25_chinese', similarity_cut_off=0.01, topk=3, output_format='content', join=True)
>>> rm.start()
>>> print(rm("user query"))
>>> document7 = Document(dataset_path='/path/to/user/data', embed=m, manager=False)
>>> rm = Retriever(document7, group_name='CoarseChunk', similarity='bm25_chinese', similarity_cut_off=0.01, topk=3, output_format='dict')
>>> rm.start()
>>> print(rm("user query"))
Source code in lazyllm/tools/rag/retriever.py
class Retriever(ModuleBase, _PostProcess):
    """
创建一个用于文档查询和检索的检索模块。此构造函数初始化一个检索模块,该模块根据指定的相似度度量配置文档检索过程。

Args:
    doc: 文档模块实例。该文档模块可以是单个实例,也可以是一个实例的列表。如果是单个实例,表示对单个Document进行检索,如果是实例的列表,则表示对多个Document进行检索。
    group_name: 在哪个 node group 上进行检索。
    similarity: 用于设置文档检索的相似度函数。默认为 'dummy'。候选集包括 ["bm25", "bm25_chinese", "cosine"]。
    similarity_cut_off: 当相似度低于指定值时丢弃该文档。在多 embedding 场景下,如果需要对不同的 embedding 指定不同的值,则需要使用字典的方式指定,key 表示指定的是哪个 embedding,value 表示相应的阈值。如果所有的 embedding 使用同一个阈值,则只指定一个数值即可。
    index: 用于文档检索的索引类型。目前仅支持 'default'。
    topk: 表示取相似度最高的多少篇文档。
    embed_keys: 表示通过哪些 embedding 做检索,不指定表示用全部 embedding 进行检索。
    target:目标组名,将结果转换到目标组。
    output_format: 代表输出格式,默认为None,可选值有 'content' 和 'dict',其中 content 对应输出格式为字符串,dict 对应字典。
    join: 是否联合输出的 k 个节点,当输出格式为 content 时,如果设置该值为 True,则输出一个长字符串,如果设置为 False 则输出一个字符串列表,其中每个字符串对应每个节点的文本内容。当输出格式是 dict 时,不能联合输出,此时join默认为False,,将输出一个字典,包括'content、'embedding'、'metadata'三个key。

其中 `group_name` 有三个内置的切分策略,都是使用 `SentenceSplitter` 做切分,区别在于块大小不同:

- CoarseChunk: 块大小为 1024,重合长度为 100
- MediumChunk: 块大小为 256,重合长度为 25
- FineChunk: 块大小为 128,重合长度为 12

此外,LazyLLM提供了内置的`Image`节点组存储了所有图像节点,支持图像嵌入和检索。


Examples:

    >>> import lazyllm
    >>> from lazyllm.tools import Retriever, Document, SentenceSplitter
    >>> m = lazyllm.OnlineEmbeddingModule()
    >>> documents = Document(dataset_path='/path/to/user/data', embed=m, manager=False)
    >>> rm = Retriever(documents, group_name='CoarseChunk', similarity='bm25', similarity_cut_off=0.01, topk=6)
    >>> rm.start()
    >>> print(rm("user query"))
    >>> m1 = lazyllm.TrainableModule('bge-large-zh-v1.5').start()
    >>> document1 = Document(dataset_path='/path/to/user/data', embed={'online':m , 'local': m1}, manager=False)
    >>> document1.create_node_group(name='sentences', transform=SentenceSplitter, chunk_size=1024, chunk_overlap=100)
    >>> retriever = Retriever(document1, group_name='sentences', similarity='cosine', similarity_cut_off=0.4, embed_keys=['local'], topk=3)
    >>> print(retriever("user query"))
    >>> document2 = Document(dataset_path='/path/to/user/data', embed={'online':m , 'local': m1}, manager=False)
    >>> document2.create_node_group(name='sentences', transform=SentenceSplitter, chunk_size=512, chunk_overlap=50)
    >>> retriever2 = Retriever([document1, document2], group_name='sentences', similarity='cosine', similarity_cut_off=0.4, embed_keys=['local'], topk=3)
    >>> print(retriever2("user query"))
    >>>
    >>> filters = {
    >>>     "author": ["A", "B", "C"],
    >>>     "public_year": [2002, 2003, 2004],
    >>> }
    >>> document3 = Document(dataset_path='/path/to/user/data', embed={'online':m , 'local': m1}, manager=False)
    >>> document3.create_node_group(name='sentences', transform=SentenceSplitter, chunk_size=512, chunk_overlap=50)
    >>> retriever3 = Retriever([document1, document3], group_name='sentences', similarity='cosine', similarity_cut_off=0.4, embed_keys=['local'], topk=3)
    >>> print(retriever3(query="user query", filters=filters))
    >>> document4 = Document(dataset_path='/path/to/user/data', embed=lazyllm.TrainableModule('siglip'))
    >>> retriever4 = Retriever(document4, group_name='Image', similarity='cosine')
    >>> nodes = retriever4("user query")
    >>> print([node.get_content() for node in nodes])
    >>> document5 = Document(dataset_path='/path/to/user/data', embed=m, manager=False)
    >>> rm = Retriever(document5, group_name='CoarseChunk', similarity='bm25_chinese', similarity_cut_off=0.01, topk=3, output_format='content')
    >>> rm.start()
    >>> print(rm("user query"))
    >>> document6 = Document(dataset_path='/path/to/user/data', embed=m, manager=False)
    >>> rm = Retriever(document6, group_name='CoarseChunk', similarity='bm25_chinese', similarity_cut_off=0.01, topk=3, output_format='content', join=True)
    >>> rm.start()
    >>> print(rm("user query"))
    >>> document7 = Document(dataset_path='/path/to/user/data', embed=m, manager=False)
    >>> rm = Retriever(document7, group_name='CoarseChunk', similarity='bm25_chinese', similarity_cut_off=0.01, topk=3, output_format='dict')
    >>> rm.start()
    >>> print(rm("user query"))
    """
    def __init__(self, doc: object, group_name: str, similarity: Optional[str] = None,
                 similarity_cut_off: Union[float, Dict[str, float]] = float("-inf"), index: str = "default",
                 topk: int = 6, embed_keys: Optional[List[str]] = None, target: Optional[str] = None,
                 output_format: Optional[str] = None, join: Union[bool, str] = False, **kwargs):
        super().__init__()

        if similarity:
            _, mode, _ = registered_similarities[similarity]
        else:
            mode = 'embedding'  # TODO FIXME XXX should be removed after similarity args refactor
        group_name, target = str(group_name), (str(target) if target else None)

        self._docs: List[Document] = [doc] if isinstance(doc, Document) else doc
        for doc in self._docs:
            assert isinstance(doc, (Document, UrlDocument)), 'Only Document or List[Document] are supported'
            if isinstance(doc, UrlDocument): continue
            self._submodules.append(doc)
            if mode == 'embedding' and embed_keys is None:
                embed_keys = list(doc._impl.embed.keys())
            doc.activate_group(group_name, embed_keys)
            if target: doc.activate_group(target)

        self._group_name = group_name
        self._similarity = similarity  # similarity function str
        self._similarity_cut_off = similarity_cut_off
        self._index = index
        self._topk = topk
        self._similarity_kw = kwargs  # kw parameters
        self._embed_keys = embed_keys
        self._target = target
        _PostProcess.__init__(self, output_format, join)

    @once_wrapper
    def _lazy_init(self):
        docs = [doc for doc in self._docs if isinstance(doc, UrlDocument) or self._group_name in doc._impl.node_groups
                or self._group_name in DocImpl._builtin_node_groups or self._group_name in DocImpl._global_node_groups]
        if not docs: raise RuntimeError(f'Group {self._group_name} not found in document {self._docs}')
        self._docs = docs

    def forward(
            self, query: str, filters: Optional[Dict[str, Union[str, int, List, Set]]] = None,
            **kwargs
    ) -> Union[List[DocNode], str]:
        self._lazy_init()
        all_nodes: List[DocNode] = []
        for doc in self._docs:
            nodes = doc.forward(query=query, group_name=self._group_name, similarity=self._similarity,
                                similarity_cut_off=self._similarity_cut_off, index=self._index,
                                topk=self._topk, similarity_kws=self._similarity_kw, embed_keys=self._embed_keys,
                                filters=filters, **kwargs)
            if nodes and self._target and self._target != nodes[0]._group:
                nodes = doc.find(self._target)(nodes)
            all_nodes.extend(nodes)
        return self._post_process(all_nodes)

lazyllm.tools.rag.retriever.TempDocRetriever

Bases: ModuleBase, _PostProcess

临时文档检索器,继承自 ModuleBase 和 _PostProcess,用于快速处理临时文件并执行检索任务。

Parameters:

  • embed (Callable, default: None ) –

    嵌入函数。

  • output_format (Optional[str], default: None ) –

    结果输出格式(如json),可选默认为None

  • join (Union[bool, str], default: False ) –

    是否合并多段结果(True或用分隔符如"

")

Examples:

>>> import lazyllm
>>> from lazyllm.tools import TempDocRetriever, Document, SentenceSplitter
>>> retriever = TempDocRetriever(output_format="text", join="
---------------
")
    retriever.create_node_group(transform=lambda text: [s.strip() for s in text.split("。") if s] )
    retriever.add_subretriever(group=Document.MediumChunk, topk=3)
    files = ["机器学习是AI的核心领域。深度学习是其重要分支。"]
    results = retriever.forward(files, "什么是机器学习?")
    print(results)
Source code in lazyllm/tools/rag/retriever.py
class TempDocRetriever(ModuleBase, _PostProcess):
    """
临时文档检索器,继承自 ModuleBase 和 _PostProcess,用于快速处理临时文件并执行检索任务。

Args:
    embed:嵌入函数。
    output_format:结果输出格式(如json),可选默认为None
    join:是否合并多段结果(True或用分隔符如"
")


Examples:

    >>> import lazyllm
    >>> from lazyllm.tools import TempDocRetriever, Document, SentenceSplitter
    >>> retriever = TempDocRetriever(output_format="text", join="
    ---------------
    ")
        retriever.create_node_group(transform=lambda text: [s.strip() for s in text.split("。") if s] )
        retriever.add_subretriever(group=Document.MediumChunk, topk=3)
        files = ["机器学习是AI的核心领域。深度学习是其重要分支。"]
        results = retriever.forward(files, "什么是机器学习?")
        print(results)
    """
    def __init__(self, embed: Callable = None, output_format: Optional[str] = None, join: Union[bool, str] = False):
        super().__init__()
        self._doc = Document(doc_files=[])
        self._embed = embed
        self._node_groups = []
        _PostProcess.__init__(self, output_format, join)

    def create_node_group(self, name: str = None, *, transform: Callable, parent: str = LAZY_ROOT_NAME,
                          trans_node: bool = None, num_workers: int = 0, **kwargs):
        """
创建具有特定处理流程的节点组。

Args:
    name (str): 节点组名称,None时自动生成。
    transform (Callable): 该组文档的处理函数。
    parent (str): 父组名称,默认为根组。
    trans_node (bool): 是否转换节点,None时继承父组设置。
    num_workers (int): 并行处理worker数,0表示串行。
    **kwargs: 其他组参数。
"""
        self._doc.create_node_group(name, transform=transform, parent=parent,
                                    trans_node=trans_node, num_workers=num_workers, **kwargs)
        return self

    def add_subretriever(self, group: str, **kwargs):
        """
添加带搜索配置的子检索器。

Args:
    group (str): 目标节点组名称。
    **kwargs: 检索器参数(如similarity='cosine')。

**Returns:**

- self: 支持链式调用。
"""
        if 'similarity' not in kwargs: kwargs['similarity'] = ('cosine' if self._embed else 'bm25')
        self._node_groups.append((group, kwargs))
        return self

    @functools.lru_cache    # noqa: B019
    def _get_retrievers(self, doc_files: List[str]):
        active_node_groups = self._node_groups or [[Document.MediumChunk,
                                                    dict(similarity=('cosine' if self._embed else 'bm25'))]]
        doc = Document(embed=self._embed, doc_files=doc_files)
        doc._impl.node_groups = self._doc._impl.node_groups
        retrievers = [Retriever(doc, name, **kw) for (name, kw) in active_node_groups]
        return retrievers

    def forward(self, files: Union[str, List[str]], query: str):
        if isinstance(files, str): files = [files]
        retrievers = self._get_retrievers(doc_files=tuple(set(files)))
        r = lazyllm.parallel(*retrievers).sum
        return self._post_process(r(query))

add_subretriever(group, **kwargs)

添加带搜索配置的子检索器。

Parameters:

  • group (str) –

    目标节点组名称。

  • **kwargs

    检索器参数(如similarity='cosine')。

Returns:

  • self: 支持链式调用。
Source code in lazyllm/tools/rag/retriever.py
    def add_subretriever(self, group: str, **kwargs):
        """
添加带搜索配置的子检索器。

Args:
    group (str): 目标节点组名称。
    **kwargs: 检索器参数(如similarity='cosine')。

**Returns:**

- self: 支持链式调用。
"""
        if 'similarity' not in kwargs: kwargs['similarity'] = ('cosine' if self._embed else 'bm25')
        self._node_groups.append((group, kwargs))
        return self

create_node_group(name=None, *, transform, parent=LAZY_ROOT_NAME, trans_node=None, num_workers=0, **kwargs)

创建具有特定处理流程的节点组。

Parameters:

  • name (str, default: None ) –

    节点组名称,None时自动生成。

  • transform (Callable) –

    该组文档的处理函数。

  • parent (str, default: LAZY_ROOT_NAME ) –

    父组名称,默认为根组。

  • trans_node (bool, default: None ) –

    是否转换节点,None时继承父组设置。

  • num_workers (int, default: 0 ) –

    并行处理worker数,0表示串行。

  • **kwargs

    其他组参数。

Source code in lazyllm/tools/rag/retriever.py
    def create_node_group(self, name: str = None, *, transform: Callable, parent: str = LAZY_ROOT_NAME,
                          trans_node: bool = None, num_workers: int = 0, **kwargs):
        """
创建具有特定处理流程的节点组。

Args:
    name (str): 节点组名称,None时自动生成。
    transform (Callable): 该组文档的处理函数。
    parent (str): 父组名称,默认为根组。
    trans_node (bool): 是否转换节点,None时继承父组设置。
    num_workers (int): 并行处理worker数,0表示串行。
    **kwargs: 其他组参数。
"""
        self._doc.create_node_group(name, transform=transform, parent=parent,
                                    trans_node=trans_node, num_workers=num_workers, **kwargs)
        return self

lazyllm.tools.rag.retriever.UrlDocument

Bases: ModuleBase

UrlDocument类继承自ModuleBase,用于通过指定的URL和名称管理远程文档资源。
内部通过lazyllm的UrlModule代理实际调用,支持文档查找、检索和活跃节点分组查询。

Parameters:

  • url (str) –

    远程文档资源的访问URL。

  • name (str, default: None ) –

    当前文档分组名称,用于标识文档分组。

Source code in lazyllm/tools/rag/document.py
class UrlDocument(ModuleBase):
    """UrlDocument类继承自ModuleBase,用于通过指定的URL和名称管理远程文档资源。  
内部通过lazyllm的UrlModule代理实际调用,支持文档查找、检索和活跃节点分组查询。  

Args:
    url (str): 远程文档资源的访问URL。
    name (str): 当前文档分组名称,用于标识文档分组。
"""
    def __init__(self, url: str, name: str = None):
        super().__init__()
        self._missing_keys = set(dir(Document)) - set(dir(UrlDocument))
        self._manager = lazyllm.UrlModule(url=ensure_call_endpoint(url))
        self._curr_group = name or DocListManager.DEFAULT_GROUP_NAME

    def _forward(self, func_name: str, *args, **kwargs):
        args = (self._curr_group, func_name, *args)
        return self._manager._call("__call__", *args, **kwargs)

    def find(self, target) -> Callable:
        """生成一个部分应用函数,用于在当前文档组中查找指定目标。

Args:
    target (str): 需要查找的目标标识。

**Returns:**

- Callable: 调用时会执行查找操作的部分应用函数。
"""
        return functools.partial(self._forward, 'find', group=target)

    def forward(self, *args, **kw):
        return self._forward('retrieve', *args, **kw)

    @cached_property
    def active_node_groups(self):
        return self._forward('active_node_groups')

    def __getattr__(self, name):
        if name in self._missing_keys:
            raise RuntimeError(f'Document generated with url and name has no attribute `{name}`')

find(target)

生成一个部分应用函数,用于在当前文档组中查找指定目标。

Parameters:

  • target (str) –

    需要查找的目标标识。

Returns:

  • Callable: 调用时会执行查找操作的部分应用函数。
Source code in lazyllm/tools/rag/document.py
    def find(self, target) -> Callable:
        """生成一个部分应用函数,用于在当前文档组中查找指定目标。

Args:
    target (str): 需要查找的目标标识。

**Returns:**

- Callable: 调用时会执行查找操作的部分应用函数。
"""
        return functools.partial(self._forward, 'find', group=target)

lazyllm.tools.rag.DocManager

Bases: ModuleBase

DocManager类管理文档列表及相关操作,并通过API提供文档上传、删除、分组等功能。

Parameters:

  • dlm (DocListManager) –

    文档列表管理器,用于处理具体的文档操作。

Source code in lazyllm/tools/rag/doc_manager.py
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
class DocManager(lazyllm.ModuleBase):
    """
DocManager类管理文档列表及相关操作,并通过API提供文档上传、删除、分组等功能。

Args:
    dlm (DocListManager): 文档列表管理器,用于处理具体的文档操作。

"""
    def __init__(self, dlm: DocListManager) -> None:
        super().__init__()
        # disable path monitoring in case of competition adding/deleting files
        self._manager = dlm
        self._manager.enable_path_monitoring = False

    def __reduce__(self):
        self._manager.enable_path_monitoring = False
        return (__class__, (self._manager,))

    @app.get("/", response_model=BaseResponse, summary="docs")
    def document(self):
        """
提供默认文档页面的重定向接口。

**Returns:**

- RedirectResponse: 重定向到 `/docs` 页面。
"""
        return RedirectResponse(url="/docs")

    @app.get("/list_kb_groups")
    def list_kb_groups(self):
        """
列出所有文档分组的接口。

**Returns:**

- BaseResponse: 包含所有文档分组的数据。
"""
        try:
            return BaseResponse(data=self._manager.list_all_kb_group())
        except Exception as e:
            return BaseResponse(code=500, msg=str(e), data=None)

    # returns an error message if invalid
    @staticmethod
    def _validate_metadata(metadata: Dict) -> Optional[str]:
        if metadata.get(RAG_DOC_ID):
            return f"metadata MUST not contain key `{RAG_DOC_ID}`"
        if metadata.get(RAG_DOC_PATH):
            return f"metadata MUST not contain key `{RAG_DOC_PATH}`"
        return None

    def _gen_unique_filepath(self, file_path: str) -> str:
        suffix = os.path.splitext(file_path)[1]
        prefix = file_path[0: len(file_path) - len(suffix)]
        pattern = f"{prefix}%{suffix}"
        MAX_TRIES = 10000
        exist_paths = set(self._manager.get_existing_paths_by_pattern(pattern))
        if file_path not in exist_paths:
            return file_path
        for i in range(1, MAX_TRIES):
            new_path = f"{prefix}-{i}{suffix}"
            if new_path not in exist_paths:
                return new_path
        return f"{str(uuid.uuid4())}{suffix}"

    @app.post("/upload_files")
    def upload_files(self, files: List[UploadFile], override: bool = False,  # noqa C901
                     metadatas: Optional[str] = None, user_path: Optional[str] = None):
        """
上传文件并更新其状态的接口。可以同时上传多个文件。

Args:
    files (List[UploadFile]): 上传的文件列表。
    override (bool): 是否覆盖已存在的文件。默认为False。
    metadatas (Optional[str]): 文件的元数据,JSON格式。
    user_path (Optional[str]): 用户自定义的文件上传路径。

**Returns:**

- BaseResponse: 上传结果和文件ID。
"""
        try:
            if user_path: user_path = user_path.lstrip('/')
            if metadatas:
                metadatas: Optional[List[Dict[str, str]]] = json.loads(metadatas)
                if len(files) != len(metadatas):
                    return BaseResponse(code=400, msg='Length of files and metadatas should be the same',
                                        data=None)
                for idx, mt in enumerate(metadatas):
                    err_msg = self._validate_metadata(mt)
                    if err_msg:
                        return BaseResponse(code=400, msg=f'file [{files[idx].filename}]: {err_msg}', data=None)
            file_paths = [os.path.join(self._manager._path, user_path or '', file.filename) for file in files]
            paths_is_new = [True] * len(file_paths)
            if override is True:
                is_success, msg, paths_is_new = self._manager.validate_paths(file_paths)
                if not is_success:
                    return BaseResponse(code=500, msg=msg, data=None)
            directorys = set(os.path.dirname(path) for path in file_paths)
            [os.makedirs(directory, exist_ok=True) for directory in directorys if directory]
            ids, results = [], []
            for i in range(len(files)):
                file_path = file_paths[i]
                content = files[i].file.read()
                metadata = metadatas[i] if metadatas else None
                if override is False:
                    file_path = self._gen_unique_filepath(file_path)
                with open(file_path, 'wb') as f: f.write(content)
                msg = "success"
                doc_id = gen_docid(file_path)
                if paths_is_new[i]:
                    docs = self._manager.add_files(
                        [file_path], metadatas=[metadata], status=DocListManager.Status.success)
                    if not docs:
                        msg = f"Failed: path {file_path} already exists in Database."
                else:
                    self._manager.update_kb_group(cond_file_ids=[doc_id], new_need_reparse=True)
                    msg = f"Success: path {file_path} will be reparsed."
                ids.append(doc_id)
                results.append(msg)
            return BaseResponse(data=[ids, results])
        except Exception as e:
            lazyllm.LOG.error(f'upload_files exception: {e}')
            return BaseResponse(code=500, msg=str(e), data=None)

    class AddFilesRequest(BaseModel):
        files: List[str]
        group_name: Optional[str] = None
        metadatas: Optional[str] = None

    @app.post("/add_files")
    def add_files(self, request: AddFilesRequest):
        """
批量添加文件。
Args:
    files (List[UploadFile]): 上传的文件列表。
    group_name (str): 目标知识库分组名称,为空时不添加到分组。
    metadatas (Optional[str]): 文件的元数据,JSON格式。
**Returns:**

- BaseResponse:返回所有输入文件对应的唯一文件ID列表,包括新增和已存在的文件。若出现异常,则返回错误码和异常信息。
"""
        files = request.files
        group_name = request.group_name
        metadatas = request.metadatas
        try:
            if metadatas:
                metadatas: Optional[List[Dict[str, str]]] = json.loads(metadatas)
                assert len(files) == len(metadatas), 'Length of files and metadatas should be the same'

            exists_files_info = self._manager.list_files(limit=None, details=True, status=DocListManager.Status.all)
            exists_files_info = {row[2]: row[0] for row in exists_files_info}

            exist_ids = []
            new_files = []
            new_metadatas = []
            id_mapping = {}

            for idx, file in enumerate(files):
                if os.path.exists(file):
                    exist_id = exists_files_info.get(file, None)
                    if exist_id:
                        update_kws = dict(fileid=exist_id, status=DocListManager.Status.success)
                        if metadatas: update_kws["meta"] = json.dumps(metadatas[idx])
                        self._manager.update_file_message(**update_kws)
                        exist_ids.append(exist_id)
                        id_mapping[file] = exist_id
                    else:
                        new_files.append(file)
                        if metadatas:
                            new_metadatas.append(metadatas[idx])
                else:
                    id_mapping[file] = None

            new_ids = self._manager.add_files(new_files, metadatas=new_metadatas, status=DocListManager.Status.success)
            if group_name:
                self._manager.add_files_to_kb_group(new_ids + exist_ids, group=group_name)

            for file, new_id in zip(new_files, new_ids):
                id_mapping[file] = new_id
            return_ids = [id_mapping[file] for file in files]

            return BaseResponse(data=return_ids)
        except Exception as e:
            return BaseResponse(code=500, msg=str(e), data=None)

    @app.get("/list_files")
    def list_files(self, limit: Optional[int] = None, details: bool = True, alive: Optional[bool] = None):
        """
列出已上传文件的接口。

Args:
    limit (Optional[int]): 返回的文件数量限制。默认为None。
    details (bool): 是否返回详细信息。默认为True。
    alive (Optional[bool]): 如果为True,只返回未删除的文件。默认为None。

**Returns:**

- BaseResponse: 文件列表数据。
"""
        try:
            status = [DocListManager.Status.success, DocListManager.Status.waiting, DocListManager.Status.working,
                      DocListManager.Status.failed] if alive else DocListManager.Status.all
            return BaseResponse(data=self._manager.list_files(limit=limit, details=details, status=status))
        except Exception as e:
            return BaseResponse(code=500, msg=str(e), data=None)

    @app.get("/reparse_files")
    def reparse_files(self, file_ids: List[str], group_name: Optional[str] = None):
        try:
            self._manager.update_need_reparsing(file_ids, group_name)
            return BaseResponse()
        except Exception as e:
            return BaseResponse(code=500, msg=str(e), data=None)

    @app.get("/list_files_in_group")
    def list_files_in_group(self, group_name: Optional[str] = None,
                            limit: Optional[int] = None, alive: Optional[bool] = None):
        """
列出指定分组中文件的接口。

Args:
    group_name (Optional[str]): 文件分组名称。
    limit (Optional[int]): 返回的文件数量限制。默认为None。
    alive (Optional[bool]): 是否只返回未删除的文件。

**Returns:**

- BaseResponse: 分组文件列表。
"""
        try:
            status = [DocListManager.Status.success, DocListManager.Status.waiting, DocListManager.Status.working,
                      DocListManager.Status.failed] if alive else DocListManager.Status.all
            return BaseResponse(data=self._manager.list_kb_group_files(group_name, limit, details=True, status=status))
        except Exception as e:
            return BaseResponse(code=500, msg=str(e) + '\ntraceback:\n' + str(traceback.format_exc()), data=None)

    class FileGroupRequest(BaseModel):
        file_ids: List[str]
        group_name: Optional[str] = Field(None)

    @app.post("/add_files_to_group_by_id")
    def add_files_to_group_by_id(self, request: FileGroupRequest):
        """
通过文件ID将文件添加到指定分组的接口。

Args:
    request (FileGroupRequest): 包含文件ID和分组名称的请求。

**Returns:**

- BaseResponse: 操作结果。
"""
        try:
            self._manager.add_files_to_kb_group(request.file_ids, request.group_name)
            return BaseResponse()
        except Exception as e:
            return BaseResponse(code=500, msg=str(e), data=None)

    @app.post("/add_files_to_group")
    def add_files_to_group(self, files: List[UploadFile], group_name: str, override: bool = False,
                           metadatas: Optional[str] = None, user_path: Optional[str] = None):
        """
将文件上传后直接添加到指定分组的接口。

Args:
    files (List[UploadFile]): 上传的文件列表。
    group_name (str): 要添加到的分组名称。
    override (bool): 是否覆盖已存在的文件。默认为False。
    metadatas (Optional[str]): 文件元数据,JSON格式。
    user_path (Optional[str]): 用户自定义的文件上传路径。

**Returns:**

- BaseResponse: 操作结果和文件ID。
"""
        try:
            response = self.upload_files(files, override=override, metadatas=metadatas, user_path=user_path)
            if response.code != 200: return response
            ids = response.data[0]
            self._manager.add_files_to_kb_group(ids, group_name)
            return BaseResponse(data=ids)
        except Exception as e:
            return BaseResponse(code=500, msg=str(e), data=None)

    @app.post("/delete_files")
    def delete_files(self, request: FileGroupRequest):
        """
删除指定文件的接口。

Args:
    request (FileGroupRequest): 包含文件ID和分组名称的请求。

**Returns:**

- BaseResponse: 删除操作结果。
"""
        try:
            if request.group_name:
                return self.delete_files_from_group(request)
            else:
                documents = self._manager.delete_files(request.file_ids)
                deleted_ids = set([ele.doc_id for ele in documents])
                for doc in documents:
                    if os.path.exists(path := doc.path):
                        os.remove(path)
                results = ["Success" if ele.doc_id in deleted_ids else "Failed" for ele in documents]
                return BaseResponse(data=[request.file_ids, results])
        except Exception as e:
            return BaseResponse(code=500, msg=str(e), data=None)

    @app.post("/delete_files_from_group")
    def delete_files_from_group(self, request: FileGroupRequest):
        """
删除指定分组中的文件的接口。

Args:
    request (FileGroupRequest): 包含文件ID列表和分组名称的请求参数。

**Returns:**

- BaseResponse: 删除操作结果。
"""
        try:
            self._manager.update_kb_group(cond_file_ids=request.file_ids, cond_group=request.group_name,
                                          new_status=DocListManager.Status.deleting)
            return BaseResponse()
        except Exception as e:
            return BaseResponse(code=500, msg=str(e), data=None)

    class AddMetadataRequest(BaseModel):
        doc_ids: List[str]
        kv_pair: Dict[str, Union[bool, int, float, str, list]]

    @app.post("/add_metadata")
    def add_metadata(self, add_metadata_request: AddMetadataRequest):
        """
为指定文档添加或更新元数据的接口。

Args:
    add_metadata_request (AddMetadataRequest): 包含文档ID列表和键值对元数据的请求。

**Returns:**

- BaseResponse: 操作结果信息。
"""
        doc_ids = add_metadata_request.doc_ids
        kv_pair = add_metadata_request.kv_pair
        try:
            docs = self._manager.get_docs(doc_ids)
            if not docs:
                return BaseResponse(code=400, msg="Failed, no doc found")
            doc_meta = {}
            for doc in docs:
                meta_dict = json.loads(doc.meta) if doc.meta else {}
                for k, v in kv_pair.items():
                    if k not in meta_dict or not meta_dict[k]:
                        meta_dict[k] = v
                    elif isinstance(meta_dict[k], list):
                        meta_dict[k].extend(v) if isinstance(v, list) else meta_dict[k].append(v)
                    else:
                        meta_dict[k] = ([meta_dict[k]] + v) if isinstance(v, list) else [meta_dict[k], v]
                doc_meta[doc.doc_id] = meta_dict
            self._manager.set_docs_new_meta(doc_meta)
            return BaseResponse(data=None)
        except Exception as e:
            return BaseResponse(code=500, msg=str(e), data=None)

    class DeleteMetadataRequest(BaseModel):
        doc_ids: List[str]
        keys: Optional[List[str]] = Field(None)
        kv_pair: Optional[Dict[str, Union[bool, int, float, str, list]]] = Field(None)

    def _inplace_del_meta(self, meta_dict, kv_pair: Dict[str, Union[None, bool, int, float, str, list]]):
        # alert: meta_dict is not a deepcopy
        for k, v in kv_pair.items():
            if k not in meta_dict:
                continue
            if v is None:
                meta_dict.pop(k, None)
            elif isinstance(meta_dict[k], list):
                if isinstance(v, (bool, int, float, str)):
                    v = [v]
                # delete v exists in meta_dict[k]
                meta_dict[k] = list(set(meta_dict[k]) - set(v))
            else:
                # old meta[k] not a list, use v as condition to delete the key
                if meta_dict[k] == v:
                    meta_dict.pop(k, None)

    @app.post("/delete_metadata_item")
    def delete_metadata_item(self, del_metadata_request: DeleteMetadataRequest):
        """
删除指定文档的元数据字段或字段值的接口。

Args:
    del_metadata_request (DeleteMetadataRequest): 包含文档ID列表、字段名和键值对删除条件的请求。

**Returns:**

- BaseResponse: 操作结果信息。
"""
        doc_ids = del_metadata_request.doc_ids
        kv_pair = del_metadata_request.kv_pair
        keys = del_metadata_request.keys
        try:
            if keys is not None:
                # convert keys to kv_pair
                if kv_pair:
                    kv_pair.update({k: None for k in keys})
                else:
                    kv_pair = {k: None for k in keys}
            if not kv_pair:
                # clear metadata
                self._manager.set_docs_new_meta({doc_id: {} for doc_id in doc_ids})
            else:
                docs = self._manager.get_docs(doc_ids)
                if not docs:
                    return BaseResponse(code=400, msg="Failed, no doc found")
                doc_meta = {}
                for doc in docs:
                    meta_dict = json.loads(doc.meta) if doc.meta else {}
                    self._inplace_del_meta(meta_dict, kv_pair)
                    doc_meta[doc.doc_id] = meta_dict
                self._manager.set_docs_new_meta(doc_meta)
            return BaseResponse(data=None)
        except Exception as e:
            return BaseResponse(code=500, msg=str(e), data=None)

    class UpdateMetadataRequest(BaseModel):
        doc_ids: List[str]
        kv_pair: Dict[str, Union[bool, int, float, str, list]]

    @app.post("/update_or_create_metadata_keys")
    def update_or_create_metadata_keys(self, update_metadata_request: UpdateMetadataRequest):
        """
更新或创建文档元数据字段的接口。
Args:
    update_metadata_request (UpdateMetadataRequest): 包含文档ID列表和需更新或新增的键值对元数据。

**Returns:**

- BaseResponse: 操作结果信息。
"""
        doc_ids = update_metadata_request.doc_ids
        kv_pair = update_metadata_request.kv_pair
        try:
            docs = self._manager.get_docs(doc_ids)
            if not docs:
                return BaseResponse(code=400, msg="Failed, no doc found")
            for doc in docs:
                doc_meta = {}
                meta_dict = json.loads(doc.meta) if doc.meta else {}
                for k, v in kv_pair.items():
                    meta_dict[k] = v
                doc_meta[doc.doc_id] = meta_dict
            self._manager.set_docs_new_meta(doc_meta)
            return BaseResponse(data=None)
        except Exception as e:
            return BaseResponse(code=500, msg=str(e), data=None)

    class ResetMetadataRequest(BaseModel):
        doc_ids: List[str]
        new_meta: Dict[str, Union[bool, int, float, str, list]]

    @app.post("/reset_metadata")
    def reset_metadata(self, reset_metadata_request: ResetMetadataRequest):
        """
重置指定文档的所有元数据字段。

Args:
    reset_metadata_request (ResetMetadataRequest): 包含文档ID列表和新的元数据字典。

**Returns:**

- BaseResponse: 操作结果信息。
"""
        doc_ids = reset_metadata_request.doc_ids
        new_meta = reset_metadata_request.new_meta
        try:
            docs = self._manager.get_docs(doc_ids)
            if not docs:
                return BaseResponse(code=400, msg="Failed, no doc found")
            self._manager.set_docs_new_meta({doc.doc_id: new_meta for doc in docs})
            return BaseResponse(data=None)
        except Exception as e:
            return BaseResponse(code=500, msg=str(e), data=None)

    class QueryMetadataRequest(BaseModel):
        doc_id: str
        key: Optional[str] = None

    @app.post("/query_metadata")
    def query_metadata(self, query_metadata_request: QueryMetadataRequest):
        """
查询指定文档的元数据。

Args:
    query_metadata_request (QueryMetadataRequest): 请求参数,包含文档ID和可选的字段名。

**Returns:**

- BaseResponse: 若指定了 key 且存在,返回对应字段值;否则返回整个 metadata;key 不存在时报错。
"""
        doc_id = query_metadata_request.doc_id
        key = query_metadata_request.key
        try:
            docs = self._manager.get_docs(doc_id)
            if not docs:
                return BaseResponse(data=None)
            doc = docs[0]
            meta_dict = json.loads(doc.meta) if doc.meta else {}
            if not key:
                return BaseResponse(data=meta_dict)
            if key not in meta_dict:
                return BaseResponse(code=400, msg=f"Failed, key {key} does not exist")
            return BaseResponse(data=meta_dict[key])
        except Exception as e:
            return BaseResponse(code=500, msg=str(e), data=None)

    def __repr__(self):
        return lazyllm.make_repr("Module", "DocManager")

add_files(request)

批量添加文件。 Args: files (List[UploadFile]): 上传的文件列表。 group_name (str): 目标知识库分组名称,为空时不添加到分组。 metadatas (Optional[str]): 文件的元数据,JSON格式。 Returns:

  • BaseResponse:返回所有输入文件对应的唯一文件ID列表,包括新增和已存在的文件。若出现异常,则返回错误码和异常信息。
Source code in lazyllm/tools/rag/doc_manager.py
    @app.post("/add_files")
    def add_files(self, request: AddFilesRequest):
        """
批量添加文件。
Args:
    files (List[UploadFile]): 上传的文件列表。
    group_name (str): 目标知识库分组名称,为空时不添加到分组。
    metadatas (Optional[str]): 文件的元数据,JSON格式。
**Returns:**

- BaseResponse:返回所有输入文件对应的唯一文件ID列表,包括新增和已存在的文件。若出现异常,则返回错误码和异常信息。
"""
        files = request.files
        group_name = request.group_name
        metadatas = request.metadatas
        try:
            if metadatas:
                metadatas: Optional[List[Dict[str, str]]] = json.loads(metadatas)
                assert len(files) == len(metadatas), 'Length of files and metadatas should be the same'

            exists_files_info = self._manager.list_files(limit=None, details=True, status=DocListManager.Status.all)
            exists_files_info = {row[2]: row[0] for row in exists_files_info}

            exist_ids = []
            new_files = []
            new_metadatas = []
            id_mapping = {}

            for idx, file in enumerate(files):
                if os.path.exists(file):
                    exist_id = exists_files_info.get(file, None)
                    if exist_id:
                        update_kws = dict(fileid=exist_id, status=DocListManager.Status.success)
                        if metadatas: update_kws["meta"] = json.dumps(metadatas[idx])
                        self._manager.update_file_message(**update_kws)
                        exist_ids.append(exist_id)
                        id_mapping[file] = exist_id
                    else:
                        new_files.append(file)
                        if metadatas:
                            new_metadatas.append(metadatas[idx])
                else:
                    id_mapping[file] = None

            new_ids = self._manager.add_files(new_files, metadatas=new_metadatas, status=DocListManager.Status.success)
            if group_name:
                self._manager.add_files_to_kb_group(new_ids + exist_ids, group=group_name)

            for file, new_id in zip(new_files, new_ids):
                id_mapping[file] = new_id
            return_ids = [id_mapping[file] for file in files]

            return BaseResponse(data=return_ids)
        except Exception as e:
            return BaseResponse(code=500, msg=str(e), data=None)

add_files_to_group(files, group_name, override=False, metadatas=None, user_path=None)

将文件上传后直接添加到指定分组的接口。

Parameters:

  • files (List[UploadFile]) –

    上传的文件列表。

  • group_name (str) –

    要添加到的分组名称。

  • override (bool, default: False ) –

    是否覆盖已存在的文件。默认为False。

  • metadatas (Optional[str], default: None ) –

    文件元数据,JSON格式。

  • user_path (Optional[str], default: None ) –

    用户自定义的文件上传路径。

Returns:

  • BaseResponse: 操作结果和文件ID。
Source code in lazyllm/tools/rag/doc_manager.py
    @app.post("/add_files_to_group")
    def add_files_to_group(self, files: List[UploadFile], group_name: str, override: bool = False,
                           metadatas: Optional[str] = None, user_path: Optional[str] = None):
        """
将文件上传后直接添加到指定分组的接口。

Args:
    files (List[UploadFile]): 上传的文件列表。
    group_name (str): 要添加到的分组名称。
    override (bool): 是否覆盖已存在的文件。默认为False。
    metadatas (Optional[str]): 文件元数据,JSON格式。
    user_path (Optional[str]): 用户自定义的文件上传路径。

**Returns:**

- BaseResponse: 操作结果和文件ID。
"""
        try:
            response = self.upload_files(files, override=override, metadatas=metadatas, user_path=user_path)
            if response.code != 200: return response
            ids = response.data[0]
            self._manager.add_files_to_kb_group(ids, group_name)
            return BaseResponse(data=ids)
        except Exception as e:
            return BaseResponse(code=500, msg=str(e), data=None)

add_files_to_group_by_id(request)

通过文件ID将文件添加到指定分组的接口。

Parameters:

  • request (FileGroupRequest) –

    包含文件ID和分组名称的请求。

Returns:

  • BaseResponse: 操作结果。
Source code in lazyllm/tools/rag/doc_manager.py
    @app.post("/add_files_to_group_by_id")
    def add_files_to_group_by_id(self, request: FileGroupRequest):
        """
通过文件ID将文件添加到指定分组的接口。

Args:
    request (FileGroupRequest): 包含文件ID和分组名称的请求。

**Returns:**

- BaseResponse: 操作结果。
"""
        try:
            self._manager.add_files_to_kb_group(request.file_ids, request.group_name)
            return BaseResponse()
        except Exception as e:
            return BaseResponse(code=500, msg=str(e), data=None)

add_metadata(add_metadata_request)

为指定文档添加或更新元数据的接口。

Parameters:

  • add_metadata_request (AddMetadataRequest) –

    包含文档ID列表和键值对元数据的请求。

Returns:

  • BaseResponse: 操作结果信息。
Source code in lazyllm/tools/rag/doc_manager.py
    @app.post("/add_metadata")
    def add_metadata(self, add_metadata_request: AddMetadataRequest):
        """
为指定文档添加或更新元数据的接口。

Args:
    add_metadata_request (AddMetadataRequest): 包含文档ID列表和键值对元数据的请求。

**Returns:**

- BaseResponse: 操作结果信息。
"""
        doc_ids = add_metadata_request.doc_ids
        kv_pair = add_metadata_request.kv_pair
        try:
            docs = self._manager.get_docs(doc_ids)
            if not docs:
                return BaseResponse(code=400, msg="Failed, no doc found")
            doc_meta = {}
            for doc in docs:
                meta_dict = json.loads(doc.meta) if doc.meta else {}
                for k, v in kv_pair.items():
                    if k not in meta_dict or not meta_dict[k]:
                        meta_dict[k] = v
                    elif isinstance(meta_dict[k], list):
                        meta_dict[k].extend(v) if isinstance(v, list) else meta_dict[k].append(v)
                    else:
                        meta_dict[k] = ([meta_dict[k]] + v) if isinstance(v, list) else [meta_dict[k], v]
                doc_meta[doc.doc_id] = meta_dict
            self._manager.set_docs_new_meta(doc_meta)
            return BaseResponse(data=None)
        except Exception as e:
            return BaseResponse(code=500, msg=str(e), data=None)

delete_files(request)

删除指定文件的接口。

Parameters:

  • request (FileGroupRequest) –

    包含文件ID和分组名称的请求。

Returns:

  • BaseResponse: 删除操作结果。
Source code in lazyllm/tools/rag/doc_manager.py
    @app.post("/delete_files")
    def delete_files(self, request: FileGroupRequest):
        """
删除指定文件的接口。

Args:
    request (FileGroupRequest): 包含文件ID和分组名称的请求。

**Returns:**

- BaseResponse: 删除操作结果。
"""
        try:
            if request.group_name:
                return self.delete_files_from_group(request)
            else:
                documents = self._manager.delete_files(request.file_ids)
                deleted_ids = set([ele.doc_id for ele in documents])
                for doc in documents:
                    if os.path.exists(path := doc.path):
                        os.remove(path)
                results = ["Success" if ele.doc_id in deleted_ids else "Failed" for ele in documents]
                return BaseResponse(data=[request.file_ids, results])
        except Exception as e:
            return BaseResponse(code=500, msg=str(e), data=None)

delete_files_from_group(request)

删除指定分组中的文件的接口。

Parameters:

  • request (FileGroupRequest) –

    包含文件ID列表和分组名称的请求参数。

Returns:

  • BaseResponse: 删除操作结果。
Source code in lazyllm/tools/rag/doc_manager.py
    @app.post("/delete_files_from_group")
    def delete_files_from_group(self, request: FileGroupRequest):
        """
删除指定分组中的文件的接口。

Args:
    request (FileGroupRequest): 包含文件ID列表和分组名称的请求参数。

**Returns:**

- BaseResponse: 删除操作结果。
"""
        try:
            self._manager.update_kb_group(cond_file_ids=request.file_ids, cond_group=request.group_name,
                                          new_status=DocListManager.Status.deleting)
            return BaseResponse()
        except Exception as e:
            return BaseResponse(code=500, msg=str(e), data=None)

delete_metadata_item(del_metadata_request)

删除指定文档的元数据字段或字段值的接口。

Parameters:

  • del_metadata_request (DeleteMetadataRequest) –

    包含文档ID列表、字段名和键值对删除条件的请求。

Returns:

  • BaseResponse: 操作结果信息。
Source code in lazyllm/tools/rag/doc_manager.py
    @app.post("/delete_metadata_item")
    def delete_metadata_item(self, del_metadata_request: DeleteMetadataRequest):
        """
删除指定文档的元数据字段或字段值的接口。

Args:
    del_metadata_request (DeleteMetadataRequest): 包含文档ID列表、字段名和键值对删除条件的请求。

**Returns:**

- BaseResponse: 操作结果信息。
"""
        doc_ids = del_metadata_request.doc_ids
        kv_pair = del_metadata_request.kv_pair
        keys = del_metadata_request.keys
        try:
            if keys is not None:
                # convert keys to kv_pair
                if kv_pair:
                    kv_pair.update({k: None for k in keys})
                else:
                    kv_pair = {k: None for k in keys}
            if not kv_pair:
                # clear metadata
                self._manager.set_docs_new_meta({doc_id: {} for doc_id in doc_ids})
            else:
                docs = self._manager.get_docs(doc_ids)
                if not docs:
                    return BaseResponse(code=400, msg="Failed, no doc found")
                doc_meta = {}
                for doc in docs:
                    meta_dict = json.loads(doc.meta) if doc.meta else {}
                    self._inplace_del_meta(meta_dict, kv_pair)
                    doc_meta[doc.doc_id] = meta_dict
                self._manager.set_docs_new_meta(doc_meta)
            return BaseResponse(data=None)
        except Exception as e:
            return BaseResponse(code=500, msg=str(e), data=None)

document()

提供默认文档页面的重定向接口。

Returns:

  • RedirectResponse: 重定向到 /docs 页面。
Source code in lazyllm/tools/rag/doc_manager.py
    @app.get("/", response_model=BaseResponse, summary="docs")
    def document(self):
        """
提供默认文档页面的重定向接口。

**Returns:**

- RedirectResponse: 重定向到 `/docs` 页面。
"""
        return RedirectResponse(url="/docs")

list_files(limit=None, details=True, alive=None)

列出已上传文件的接口。

Parameters:

  • limit (Optional[int], default: None ) –

    返回的文件数量限制。默认为None。

  • details (bool, default: True ) –

    是否返回详细信息。默认为True。

  • alive (Optional[bool], default: None ) –

    如果为True,只返回未删除的文件。默认为None。

Returns:

  • BaseResponse: 文件列表数据。
Source code in lazyllm/tools/rag/doc_manager.py
    @app.get("/list_files")
    def list_files(self, limit: Optional[int] = None, details: bool = True, alive: Optional[bool] = None):
        """
列出已上传文件的接口。

Args:
    limit (Optional[int]): 返回的文件数量限制。默认为None。
    details (bool): 是否返回详细信息。默认为True。
    alive (Optional[bool]): 如果为True,只返回未删除的文件。默认为None。

**Returns:**

- BaseResponse: 文件列表数据。
"""
        try:
            status = [DocListManager.Status.success, DocListManager.Status.waiting, DocListManager.Status.working,
                      DocListManager.Status.failed] if alive else DocListManager.Status.all
            return BaseResponse(data=self._manager.list_files(limit=limit, details=details, status=status))
        except Exception as e:
            return BaseResponse(code=500, msg=str(e), data=None)

list_files_in_group(group_name=None, limit=None, alive=None)

列出指定分组中文件的接口。

Parameters:

  • group_name (Optional[str], default: None ) –

    文件分组名称。

  • limit (Optional[int], default: None ) –

    返回的文件数量限制。默认为None。

  • alive (Optional[bool], default: None ) –

    是否只返回未删除的文件。

Returns:

  • BaseResponse: 分组文件列表。
Source code in lazyllm/tools/rag/doc_manager.py
    @app.get("/list_files_in_group")
    def list_files_in_group(self, group_name: Optional[str] = None,
                            limit: Optional[int] = None, alive: Optional[bool] = None):
        """
列出指定分组中文件的接口。

Args:
    group_name (Optional[str]): 文件分组名称。
    limit (Optional[int]): 返回的文件数量限制。默认为None。
    alive (Optional[bool]): 是否只返回未删除的文件。

**Returns:**

- BaseResponse: 分组文件列表。
"""
        try:
            status = [DocListManager.Status.success, DocListManager.Status.waiting, DocListManager.Status.working,
                      DocListManager.Status.failed] if alive else DocListManager.Status.all
            return BaseResponse(data=self._manager.list_kb_group_files(group_name, limit, details=True, status=status))
        except Exception as e:
            return BaseResponse(code=500, msg=str(e) + '\ntraceback:\n' + str(traceback.format_exc()), data=None)

list_kb_groups()

列出所有文档分组的接口。

Returns:

  • BaseResponse: 包含所有文档分组的数据。
Source code in lazyllm/tools/rag/doc_manager.py
    @app.get("/list_kb_groups")
    def list_kb_groups(self):
        """
列出所有文档分组的接口。

**Returns:**

- BaseResponse: 包含所有文档分组的数据。
"""
        try:
            return BaseResponse(data=self._manager.list_all_kb_group())
        except Exception as e:
            return BaseResponse(code=500, msg=str(e), data=None)

query_metadata(query_metadata_request)

查询指定文档的元数据。

Parameters:

  • query_metadata_request (QueryMetadataRequest) –

    请求参数,包含文档ID和可选的字段名。

Returns:

  • BaseResponse: 若指定了 key 且存在,返回对应字段值;否则返回整个 metadata;key 不存在时报错。
Source code in lazyllm/tools/rag/doc_manager.py
    @app.post("/query_metadata")
    def query_metadata(self, query_metadata_request: QueryMetadataRequest):
        """
查询指定文档的元数据。

Args:
    query_metadata_request (QueryMetadataRequest): 请求参数,包含文档ID和可选的字段名。

**Returns:**

- BaseResponse: 若指定了 key 且存在,返回对应字段值;否则返回整个 metadata;key 不存在时报错。
"""
        doc_id = query_metadata_request.doc_id
        key = query_metadata_request.key
        try:
            docs = self._manager.get_docs(doc_id)
            if not docs:
                return BaseResponse(data=None)
            doc = docs[0]
            meta_dict = json.loads(doc.meta) if doc.meta else {}
            if not key:
                return BaseResponse(data=meta_dict)
            if key not in meta_dict:
                return BaseResponse(code=400, msg=f"Failed, key {key} does not exist")
            return BaseResponse(data=meta_dict[key])
        except Exception as e:
            return BaseResponse(code=500, msg=str(e), data=None)

reset_metadata(reset_metadata_request)

重置指定文档的所有元数据字段。

Parameters:

  • reset_metadata_request (ResetMetadataRequest) –

    包含文档ID列表和新的元数据字典。

Returns:

  • BaseResponse: 操作结果信息。
Source code in lazyllm/tools/rag/doc_manager.py
    @app.post("/reset_metadata")
    def reset_metadata(self, reset_metadata_request: ResetMetadataRequest):
        """
重置指定文档的所有元数据字段。

Args:
    reset_metadata_request (ResetMetadataRequest): 包含文档ID列表和新的元数据字典。

**Returns:**

- BaseResponse: 操作结果信息。
"""
        doc_ids = reset_metadata_request.doc_ids
        new_meta = reset_metadata_request.new_meta
        try:
            docs = self._manager.get_docs(doc_ids)
            if not docs:
                return BaseResponse(code=400, msg="Failed, no doc found")
            self._manager.set_docs_new_meta({doc.doc_id: new_meta for doc in docs})
            return BaseResponse(data=None)
        except Exception as e:
            return BaseResponse(code=500, msg=str(e), data=None)

update_or_create_metadata_keys(update_metadata_request)

更新或创建文档元数据字段的接口。 Args: update_metadata_request (UpdateMetadataRequest): 包含文档ID列表和需更新或新增的键值对元数据。

Returns:

  • BaseResponse: 操作结果信息。
Source code in lazyllm/tools/rag/doc_manager.py
    @app.post("/update_or_create_metadata_keys")
    def update_or_create_metadata_keys(self, update_metadata_request: UpdateMetadataRequest):
        """
更新或创建文档元数据字段的接口。
Args:
    update_metadata_request (UpdateMetadataRequest): 包含文档ID列表和需更新或新增的键值对元数据。

**Returns:**

- BaseResponse: 操作结果信息。
"""
        doc_ids = update_metadata_request.doc_ids
        kv_pair = update_metadata_request.kv_pair
        try:
            docs = self._manager.get_docs(doc_ids)
            if not docs:
                return BaseResponse(code=400, msg="Failed, no doc found")
            for doc in docs:
                doc_meta = {}
                meta_dict = json.loads(doc.meta) if doc.meta else {}
                for k, v in kv_pair.items():
                    meta_dict[k] = v
                doc_meta[doc.doc_id] = meta_dict
            self._manager.set_docs_new_meta(doc_meta)
            return BaseResponse(data=None)
        except Exception as e:
            return BaseResponse(code=500, msg=str(e), data=None)

upload_files(files, override=False, metadatas=None, user_path=None)

上传文件并更新其状态的接口。可以同时上传多个文件。

Parameters:

  • files (List[UploadFile]) –

    上传的文件列表。

  • override (bool, default: False ) –

    是否覆盖已存在的文件。默认为False。

  • metadatas (Optional[str], default: None ) –

    文件的元数据,JSON格式。

  • user_path (Optional[str], default: None ) –

    用户自定义的文件上传路径。

Returns:

  • BaseResponse: 上传结果和文件ID。
Source code in lazyllm/tools/rag/doc_manager.py
    @app.post("/upload_files")
    def upload_files(self, files: List[UploadFile], override: bool = False,  # noqa C901
                     metadatas: Optional[str] = None, user_path: Optional[str] = None):
        """
上传文件并更新其状态的接口。可以同时上传多个文件。

Args:
    files (List[UploadFile]): 上传的文件列表。
    override (bool): 是否覆盖已存在的文件。默认为False。
    metadatas (Optional[str]): 文件的元数据,JSON格式。
    user_path (Optional[str]): 用户自定义的文件上传路径。

**Returns:**

- BaseResponse: 上传结果和文件ID。
"""
        try:
            if user_path: user_path = user_path.lstrip('/')
            if metadatas:
                metadatas: Optional[List[Dict[str, str]]] = json.loads(metadatas)
                if len(files) != len(metadatas):
                    return BaseResponse(code=400, msg='Length of files and metadatas should be the same',
                                        data=None)
                for idx, mt in enumerate(metadatas):
                    err_msg = self._validate_metadata(mt)
                    if err_msg:
                        return BaseResponse(code=400, msg=f'file [{files[idx].filename}]: {err_msg}', data=None)
            file_paths = [os.path.join(self._manager._path, user_path or '', file.filename) for file in files]
            paths_is_new = [True] * len(file_paths)
            if override is True:
                is_success, msg, paths_is_new = self._manager.validate_paths(file_paths)
                if not is_success:
                    return BaseResponse(code=500, msg=msg, data=None)
            directorys = set(os.path.dirname(path) for path in file_paths)
            [os.makedirs(directory, exist_ok=True) for directory in directorys if directory]
            ids, results = [], []
            for i in range(len(files)):
                file_path = file_paths[i]
                content = files[i].file.read()
                metadata = metadatas[i] if metadatas else None
                if override is False:
                    file_path = self._gen_unique_filepath(file_path)
                with open(file_path, 'wb') as f: f.write(content)
                msg = "success"
                doc_id = gen_docid(file_path)
                if paths_is_new[i]:
                    docs = self._manager.add_files(
                        [file_path], metadatas=[metadata], status=DocListManager.Status.success)
                    if not docs:
                        msg = f"Failed: path {file_path} already exists in Database."
                else:
                    self._manager.update_kb_group(cond_file_ids=[doc_id], new_need_reparse=True)
                    msg = f"Success: path {file_path} will be reparsed."
                ids.append(doc_id)
                results.append(msg)
            return BaseResponse(data=[ids, results])
        except Exception as e:
            lazyllm.LOG.error(f'upload_files exception: {e}')
            return BaseResponse(code=500, msg=str(e), data=None)

lazyllm.tools.rag.utils.SqliteDocListManager

Bases: DocListManager

基于 SQLite 的文档管理器,用于本地文件的持久化存储、状态管理与元信息追踪。

该类继承自 DocListManager,利用 SQLite 数据库存储文档记录。适用于管理具有唯一标识符的本地文档资源,并提供便捷的插入、查询、更新与状态过滤接口,支持可选的路径监控功能。

Parameters:

  • path (str) –

    数据库存储路径。

  • name (str) –

    数据库文件名(不包含路径)。

  • enable_path_monitoring (bool, default: True ) –

    是否启用对文件路径的变动监控,默认为 True。

Examples:

>>> from lazyllm.tools.rag.utils import SqliteDocListManager
>>> manager = SqliteDocListManager(path="./data", name="docs.sqlite")
>>> manager.insert({"uid": "doc_001", "name": "example.txt", "status": "ready"})
>>> print(manager.get("doc_001"))
>>> files = manager.list_files(limit=5, details=True)
>>> print(files)
Source code in lazyllm/tools/rag/utils.py
 592
 593
 594
 595
 596
 597
 598
 599
 600
 601
 602
 603
 604
 605
 606
 607
 608
 609
 610
 611
 612
 613
 614
 615
 616
 617
 618
 619
 620
 621
 622
 623
 624
 625
 626
 627
 628
 629
 630
 631
 632
 633
 634
 635
 636
 637
 638
 639
 640
 641
 642
 643
 644
 645
 646
 647
 648
 649
 650
 651
 652
 653
 654
 655
 656
 657
 658
 659
 660
 661
 662
 663
 664
 665
 666
 667
 668
 669
 670
 671
 672
 673
 674
 675
 676
 677
 678
 679
 680
 681
 682
 683
 684
 685
 686
 687
 688
 689
 690
 691
 692
 693
 694
 695
 696
 697
 698
 699
 700
 701
 702
 703
 704
 705
 706
 707
 708
 709
 710
 711
 712
 713
 714
 715
 716
 717
 718
 719
 720
 721
 722
 723
 724
 725
 726
 727
 728
 729
 730
 731
 732
 733
 734
 735
 736
 737
 738
 739
 740
 741
 742
 743
 744
 745
 746
 747
 748
 749
 750
 751
 752
 753
 754
 755
 756
 757
 758
 759
 760
 761
 762
 763
 764
 765
 766
 767
 768
 769
 770
 771
 772
 773
 774
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
class SqliteDocListManager(DocListManager):
    """基于 SQLite 的文档管理器,用于本地文件的持久化存储、状态管理与元信息追踪。

该类继承自 DocListManager,利用 SQLite 数据库存储文档记录。适用于管理具有唯一标识符的本地文档资源,并提供便捷的插入、查询、更新与状态过滤接口,支持可选的路径监控功能。

Args:
    path (str): 数据库存储路径。
    name (str): 数据库文件名(不包含路径)。
    enable_path_monitoring (bool): 是否启用对文件路径的变动监控,默认为 True。


Examples:
    >>> from lazyllm.tools.rag.utils import SqliteDocListManager
    >>> manager = SqliteDocListManager(path="./data", name="docs.sqlite")
    >>> manager.insert({"uid": "doc_001", "name": "example.txt", "status": "ready"})
    >>> print(manager.get("doc_001"))
    >>> files = manager.list_files(limit=5, details=True)
    >>> print(files)
    """
    def __init__(self, path, name, enable_path_monitoring=True):
        super().__init__(path, name, enable_path_monitoring)

    def _init_sql(self):
        root_dir = os.path.expanduser(os.path.join(config['home'], '.dbs'))
        os.makedirs(root_dir, exist_ok=True)
        self._db_path = os.path.join(root_dir, f'.lazyllm_dlmanager.{self._id}.db')
        self._db_lock = FileLock(self._db_path + '.lock')
        # ensure that this connection is not used in another thread when sqlite3 is not threadsafe
        self._check_same_thread = not sqlite3_check_threadsafety()
        self._engine = sqlalchemy.create_engine(
            f"sqlite:///{self._db_path}?check_same_thread={self._check_same_thread}"
        )
        self._Session = sessionmaker(bind=self._engine)
        self.init_tables()

    def _init_tables(self):
        KBDataBase.metadata.create_all(bind=self._engine)

    def table_inited(self):
        """检查数据库中是否已存在名为 "documents" 的表。

该方法通过查询 sqlite_master 元信息表,判断数据表是否已初始化。

**Returns:**

- bool: 如果 "documents" 表存在,返回 True;否则返回 False。
"""
        with self._db_lock, sqlite3.connect(self._db_path, check_same_thread=self._check_same_thread) as conn:
            cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='documents'")
            return cursor.fetchone() is not None

    @staticmethod
    def get_status_cond_and_params(status: Union[str, List[str]],
                                   exclude_status: Optional[Union[str, List[str]]] = None,
                                   prefix: str = None):
        conds, params = [], []
        prefix = f'{prefix}.' if prefix else ''
        if isinstance(status, str):
            if status != DocListManager.Status.all:
                conds.append(f'{prefix}status = ?')
                params.append(status)
        elif isinstance(status, (tuple, list)):
            conds.append(f'{prefix}status IN ({",".join("?" * len(status))})')
            params.extend(status)

        if isinstance(exclude_status, str):
            assert exclude_status != DocListManager.Status.all, 'Invalid status provided'
            conds.append(f'{prefix}status != ?')
            params.append(exclude_status)
        elif isinstance(exclude_status, (tuple, list)):
            conds.append(f'{prefix}status NOT IN ({",".join("?" * len(exclude_status))})')
            params.extend(exclude_status)

        return ' AND '.join(conds), params

    def _get_all_docs(self):
        with self._db_lock, self._Session() as session:
            return session.query(KBDocument).all()

    def _get_docs(self, to_be_added_doc_ids: List, to_be_deleted_doc_ids: List, filter_status_list: List):
        with self._db_lock, self._Session() as session:
            docs_not_expected = session.query(KBDocument).filter(KBDocument.doc_id.in_(to_be_added_doc_ids)).all()
            docs_expected = session.query(KBDocument).filter(KBDocument.doc_id.in_(to_be_deleted_doc_ids),
                                                             KBDocument.status.in_(filter_status_list)).all()
        return docs_not_expected, docs_expected

    def validate_paths(self, paths: List[str]) -> Tuple[bool, str, List[bool]]:
        """验证输入路径所对应的文档是否可以安全添加到数据库。

该方法会检查每个路径是否对应已有文档,若已存在,需判断其状态是否允许重解析。
若文档正在解析或等待解析,或上次重解析未完成,则视为不可用。

Args:
    paths (List[str]): 文件路径列表。

**Returns:**

- Tuple[bool, str, List[bool]]: 
    - bool: 是否所有路径都验证通过。
    - str: 成功或失败的描述信息。
    - List[bool]: 与输入路径一一对应的布尔列表,表示该路径是否为新文档(True 为新文档,False 为已存在)。
        若验证失败,返回值为 None。
"""
        # check and return: success, msg, path_is_new for each path
        unsafe_staus_set = set([DocListManager.Status.working, DocListManager.Status.waiting])
        paths_is_new = [True] * len(paths)
        doc_ids = [gen_docid(path) for path in paths]
        doc_id_to_path = {doc_id: path for doc_id, path in zip(doc_ids, paths)}
        found_doc_ids = []
        found_doc_group_rows = []
        with self._db_lock, self._Session() as session:
            rows = session.execute(
                select(KBDocument.doc_id).where(KBDocument.doc_id.in_(doc_ids))
            ).fetchall()
            if len(rows) == 0:
                return True, "Success", paths_is_new
            found_doc_ids = [row.doc_id for row in rows]
            found_doc_group_rows = session.execute(
                select(KBGroupDocuments.doc_id, KBGroupDocuments.need_reparse, KBGroupDocuments.status)
                .where(KBGroupDocuments.doc_id.in_(found_doc_ids))).fetchall()

        for doc_group_record in found_doc_group_rows:
            if doc_group_record.need_reparse:
                msg = f"Failed: {doc_id_to_path[doc_group_record.doc_id]} lasttime reparsing has not been finished"
                return False, msg, None
            if doc_group_record.status in unsafe_staus_set:
                return False, f"Failed: {doc_id_to_path[doc_group_record.doc_id]} is being parsed by kbgroup", None
        found_doc_ids = set(found_doc_ids)
        for i in range(len(paths)):
            cur_doc_id = doc_ids[i]
            if cur_doc_id in found_doc_ids:
                paths_is_new[i] = False
        return True, "Success", paths_is_new

    def update_need_reparsing(self, doc_id: str, need_reparse: bool, group_name: Optional[str] = None):
        """更新指定文档的重解析标志位。

该方法用于设置某个文档是否需要重新解析。可以选择性地指定知识库分组进行精确匹配。

Args:
    doc_id (str): 文档的唯一标识符。
    need_reparse (bool): 是否需要重新解析文档。
    group_name (Optional[str]): 可选,所属的知识库分组名称。如果提供,将仅更新指定分组中的文档。
"""
        with self._db_lock, self._Session() as session:
            stmt = update(KBGroupDocuments).where(KBGroupDocuments.doc_id == doc_id)
            if group_name is not None: stmt = stmt.where(KBGroupDocuments.group_name == group_name)
            session.execute(stmt.values(need_reparse=need_reparse))
            session.commit()

    def list_files(self, limit: Optional[int] = None, details: bool = False,
                   status: Union[str, List[str]] = DocListManager.Status.all,
                   exclude_status: Optional[Union[str, List[str]]] = None):
        """列出文档数据库中符合状态条件的文件,并根据参数选择返回完整记录或仅返回文件路径。

Args:
    limit (Optional[int]): 要返回的记录数上限,若为 None 则返回所有符合条件的记录。
    details (bool): 是否返回完整的数据库行信息,若为 False 则仅返回文档路径(ID)。
    status (Union[str, List[str]]): 要包含在结果中的状态值,默认为包含所有状态。
    exclude_status (Optional[Union[str, List[str]]]): 要从结果中排除的状态值。

**Returns:**

- list: 文件记录列表或文档路径列表,具体取决于 `details` 参数。
"""
        query = "SELECT * FROM documents"
        params = []
        status_cond, status_params = self.get_status_cond_and_params(status, exclude_status, prefix=None)
        if status_cond:
            query += f' WHERE {status_cond}'
            params.extend(status_params)
        if limit:
            query += " LIMIT ?"
            params.append(limit)
        with self._db_lock, sqlite3.connect(self._db_path, check_same_thread=self._check_same_thread) as conn:
            cursor = conn.execute(query, params)
            return cursor.fetchall() if details else [row[0] for row in cursor]

    def get_docs(self, doc_ids: List[str]) -> List[KBDocument]:
        """根据给定的文档ID列表,从数据库中获取对应的文档对象列表。

Args:
    doc_ids (List[str]): 需要查询的文档ID列表。

**Returns:**

- List[KBDocument]: 匹配的文档对象列表。如果没有匹配项,返回空列表。
"""
        with self._db_lock, self._Session() as session:
            docs = session.query(KBDocument).filter(KBDocument.doc_id.in_(doc_ids)).all()
            return docs
        return []

    def set_docs_new_meta(self, doc_meta: Dict[str, dict]):
        """批量更新文档的元数据(meta),同时更新对应知识库分组中文档的 new_meta 字段(非等待状态的文档)。

Args:
    doc_meta (Dict[str, dict]): 字典,键为文档ID,值为对应的新元数据字典。
"""
        data_to_update = [{"_doc_id": k, "_meta": json.dumps(v)} for k, v in doc_meta.items()]
        with self._db_lock, self._Session() as session:
            # Use sqlalchemy core bulk update
            stmt = KBDocument.__table__.update().where(
                KBDocument.doc_id == bindparam("_doc_id")).values(meta=bindparam("_meta"))
            session.execute(stmt, data_to_update)
            session.commit()

            stmt = KBGroupDocuments.__table__.update().where(
                KBGroupDocuments.doc_id == bindparam("_doc_id"),
                KBGroupDocuments.status != DocListManager.Status.waiting).values(new_meta=bindparam("_meta"))
            session.execute(stmt, data_to_update)
            session.commit()

    def fetch_docs_changed_meta(self, group: str) -> List[DocMetaChangedRow]:
        """获取指定知识库分组中元数据发生变化的文档列表,并将对应的 new_meta 字段清空。

Args:
    group (str): 知识库分组名称。

**Returns:**

- List[DocMetaChangedRow]: 包含文档ID及其对应新元数据的列表。
"""
        rows = []
        conds = [KBGroupDocuments.group_name == group, KBGroupDocuments.new_meta.isnot(None)]
        with self._db_lock, self._Session() as session:
            rows = (
                session.query(KBDocument.doc_id, KBGroupDocuments.new_meta)
                .join(KBGroupDocuments, KBDocument.doc_id == KBGroupDocuments.doc_id)
                .filter(*conds).all()
            )
            stmt = update(KBGroupDocuments).where(sqlalchemy.and_(*conds)).values(new_meta=None)
            session.execute(stmt)
            session.commit()
        return rows

    def list_all_kb_group(self):
        """列出数据库中所有的知识库分组名称。

**Returns:**

- List[str]: 知识库分组名称列表。
"""
        with self._db_lock, sqlite3.connect(self._db_path, check_same_thread=self._check_same_thread) as conn:
            cursor = conn.execute("SELECT group_name FROM document_groups")
            return [row[0] for row in cursor]

    def add_kb_group(self, name):
        """向数据库中添加一个新的知识库分组名称,若已存在则忽略。

Args:
    name (str): 要添加的知识库分组名称。
"""
        with self._db_lock, sqlite3.connect(self._db_path, check_same_thread=self._check_same_thread) as conn:
            conn.execute('INSERT OR IGNORE INTO document_groups (group_name) VALUES (?)', (name,))
            conn.commit()

    def list_kb_group_files(self, group: str = None, limit: Optional[int] = None, details: bool = False,
                            status: Union[str, List[str]] = DocListManager.Status.all,
                            exclude_status: Optional[Union[str, List[str]]] = None,
                            upload_status: Union[str, List[str]] = DocListManager.Status.all,
                            exclude_upload_status: Optional[Union[str, List[str]]] = None,
                            need_reparse: Optional[bool] = None):
        """列出指定知识库分组中的文件信息,可根据多种条件进行过滤。

Args:
    group (str, optional): 知识库分组名称,若为 None 则不按分组过滤。
    limit (int, optional): 限制返回的文件数量。
    details (bool): 是否返回详细的文件信息。
    status (str or List[str], optional): 过滤知识库分组中文件的状态。
    exclude_status (str or List[str], optional): 排除指定状态的文件。
    upload_status (str or List[str], optional): 过滤文件上传状态。
    exclude_upload_status (str or List[str], optional): 排除指定的上传状态。
    need_reparse (bool, optional): 是否只返回需要重新解析的文件。

**Returns:**

- list: 
    - 如果 details 为 False,返回列表,每个元素为 (doc_id, path) 元组。
    - 如果 details 为 True,返回包含文件详细信息的元组列表,包括文档ID、路径、状态、元数据,
      知识库分组名、分组内状态及日志。
"""
        query = """
            SELECT documents.doc_id, documents.path, documents.status, documents.meta,
                   kb_group_documents.group_name, kb_group_documents.status, kb_group_documents.log
            FROM kb_group_documents
            JOIN documents ON kb_group_documents.doc_id = documents.doc_id
        """
        conds, params = [], []
        if group:
            conds.append('kb_group_documents.group_name = ?')
            params.append(group)

        if need_reparse is not None:
            conds.append('kb_group_documents.need_reparse = ?')
            params.append(int(need_reparse))

        status_cond, status_params = self.get_status_cond_and_params(status, exclude_status, prefix='kb_group_documents')
        if status_cond:
            conds.append(status_cond)
            params.extend(status_params)

        status_cond, status_params = self.get_status_cond_and_params(
            upload_status, exclude_upload_status, prefix='documents')
        if status_cond:
            conds.append(status_cond)
            params.extend(status_params)

        if conds: query += ' WHERE ' + ' AND '.join(conds)

        if limit:
            query += ' LIMIT ?'
            params.append(limit)

        with self._db_lock, sqlite3.connect(self._db_path, check_same_thread=self._check_same_thread) as conn:
            cursor = conn.execute(query, params)
            rows = cursor.fetchall()

        if not details: return [row[:2] for row in rows]
        return rows

    def delete_unreferenced_doc(self):
        """删除数据库中标记为删除且未被任何知识库分组引用的文档记录。

该方法会查找状态为“deleting”且引用计数为0的文档,删除这些文档记录,并记录删除操作日志。

"""
        with self._db_lock, self._Session() as session:
            docs_to_delete = (
                session.query(KBDocument)
                .filter(KBDocument.status == DocListManager.Status.deleting, KBDocument.count == 0)
                .all()
            )
            for doc in docs_to_delete:
                session.delete(doc)
                log = KBOperationLogs(log=f"Delete obsolete file, doc_id:{doc.doc_id}, path:{doc.path}.")
                session.add(log)
            session.commit()

    def _add_doc_records(self, files: List[str], metadatas: Optional[List[Dict[str, Any]]] = None,
                         status: Optional[str] = DocListManager.Status.waiting, batch_size: int = 64):
        documents = []

        for i in range(0, len(files), batch_size):
            batch_files = files[i:i + batch_size]
            batch_metadatas = metadatas[i:i + batch_size] if metadatas else [None] * batch_size
            vals = []

            for i, file_path in enumerate(batch_files):
                doc_id = gen_docid(file_path)

                metadata = batch_metadatas[i].copy() if batch_metadatas[i] else {}
                metadata.setdefault(RAG_DOC_ID, doc_id)
                metadata.setdefault(RAG_DOC_PATH, file_path)

                vals.append(
                    {
                        KBDocument.doc_id.name: doc_id,
                        KBDocument.filename.name: os.path.basename(file_path),
                        KBDocument.path.name: file_path,
                        KBDocument.meta.name: json.dumps(metadata),
                        KBDocument.status.name: status,
                        KBDocument.count.name: 0,
                    }
                )
            with self._db_lock, self._Session() as session:
                rows = session.execute(
                    insert(KBDocument)
                    .values(vals)
                    .prefix_with('OR IGNORE')
                    .returning(KBDocument.doc_id, KBDocument.path)
                ).fetchall()
                session.commit()
                documents.extend(rows)
        return documents

    def get_docs_need_reparse(self, group: str) -> List[KBDocument]:
        """获取指定知识库分组中需要重新解析的文档列表。

仅返回状态为“success”或“failed”的文档,且其对应的知识库分组记录标记为需要重新解析。

Args:
    group (str): 知识库分组名称。

**Returns:**

- List[KBDocument]: 需要重新解析的文档列表。
"""
        with self._db_lock, self._Session() as session:
            filter_status_list = [DocListManager.Status.success, DocListManager.Status.failed]
            documents = (
                session.query(KBDocument).join(KBGroupDocuments, KBDocument.doc_id == KBGroupDocuments.doc_id)
                .filter(KBGroupDocuments.need_reparse.is_(True),
                        KBGroupDocuments.group_name == group,
                        KBGroupDocuments.status.in_(filter_status_list)).all())
            return documents
        return []

    def get_existing_paths_by_pattern(self, pattern: str) -> List[str]:
        """根据路径匹配模式获取已存在的文档路径列表。

Args:
    pattern (str): 路径匹配模式,支持SQL的LIKE通配符。

**Returns:**

- List[str]: 匹配到的已存在文档路径列表。
"""
        exist_paths = []
        with self._db_lock, self._Session() as session:
            docs = session.query(KBDocument).filter(KBDocument.path.like(pattern)).all()
            exist_paths = [doc.path for doc in docs]
        return exist_paths

    # TODO(wangzhihong): set to metadatas and enable this function
    def update_file_message(self, fileid: str, **kw):
        """更新指定文件的字段信息。

Args:
    fileid (str): 文件的唯一标识符(doc_id)。
    **kw: 需要更新的字段及其对应的值,键值对形式传入。
"""
        set_clause = ", ".join([f"{k} = ?" for k in kw.keys()])
        params = list(kw.values()) + [fileid]
        with self._db_lock, sqlite3.connect(self._db_path, check_same_thread=self._check_same_thread) as conn:
            conn.execute(f"UPDATE documents SET {set_clause} WHERE doc_id = ?", params)
            conn.commit()

    def update_file_status(self, file_ids: List[str], status: str,
                           cond_status_list: Union[None, List[str]] = None) -> List[DocPartRow]:
        """更新多个文件的状态,支持根据当前状态进行条件过滤。

Args:
    file_ids (List[str]): 需要更新状态的文件ID列表。
    status (str): 要设置的新状态。
    cond_status_list (Union[None, List[str]], optional): 仅更新当前状态在此列表中的文件,默认为 None,表示不筛选。

**Returns:**

- List[DocPartRow]: 返回更新后的文件ID和路径列表。
"""
        rows = []
        if cond_status_list is None:
            sql_cond = KBDocument.doc_id.in_(file_ids)
        else:
            sql_cond = sqlalchemy.and_(KBDocument.status.in_(cond_status_list), KBDocument.doc_id.in_(file_ids))
        with self._db_lock, self._Session() as session:
            stmt = (
                update(KBDocument)
                .where(sql_cond)
                .values(status=status)
                .returning(KBDocument.doc_id, KBDocument.path)
            )
            rows = session.execute(stmt).fetchall()
            session.commit()
        return rows

    def add_files_to_kb_group(self, file_ids: List[str], group: str):
        """将多个文件添加到指定的知识库分组中。

该方法会将文件状态设置为等待处理(waiting),
若添加成功,则对应文档的计数(count)加一。

Args:
    file_ids (List[str]): 需要添加的文件ID列表。
    group (str): 知识库分组名称。
"""
        with self._db_lock, self._Session() as session:
            vals = []
            for doc_id in file_ids:
                vals = {
                    KBGroupDocuments.doc_id.name: doc_id,
                    KBGroupDocuments.group_name.name: group,
                    KBGroupDocuments.status.name: DocListManager.Status.waiting,
                }
                rows = session.execute(
                    insert(KBGroupDocuments).values(vals).prefix_with('OR IGNORE').returning(KBGroupDocuments.doc_id)
                ).fetchall()
                session.commit()
                if not rows:
                    continue
                doc = session.query(KBDocument).filter_by(doc_id=rows[0].doc_id).one()
                doc.count += 1
                session.commit()

    def delete_files_from_kb_group(self, file_ids: List[str], group: str):
        """从指定的知识库分组中删除多个文件。

删除成功后,对应文档的计数(count)减少,但不会低于0。
若文档不存在,会记录警告日志。

Args:
    file_ids (List[str]): 需要删除的文件ID列表。
    group (str): 知识库分组名称。
"""
        with self._db_lock, self._Session() as session:
            for doc_id in file_ids:
                records_to_delete = (
                    session.query(KBGroupDocuments)
                    .filter(KBGroupDocuments.doc_id == doc_id, KBGroupDocuments.group_name == group)
                    .all()
                )
                for record in records_to_delete:
                    session.delete(record)
                session.commit()
                if not records_to_delete:
                    continue
                try:
                    doc = session.query(KBDocument).filter_by(doc_id=records_to_delete[0].doc_id).one()
                    doc.count = max(0, doc.count - 1)
                    session.commit()
                except NoResultFound:
                    lazyllm.LOG.warning(f"No document found for {doc_id}")

    def get_file_status(self, fileid: str):
        """获取指定文件的状态。

Args:
    fileid (str): 文件的唯一标识符。

**Returns:**

- Optional[Tuple]: 返回包含状态的元组,若文件不存在则返回 None。
"""
        with self._db_lock, sqlite3.connect(self._db_path, check_same_thread=self._check_same_thread) as conn:
            cursor = conn.execute("SELECT status FROM documents WHERE doc_id = ?", (fileid,))
        return cursor.fetchone()

    def update_kb_group(self, cond_file_ids: List[str], cond_group: Optional[str] = None,
                        cond_status_list: Optional[List[str]] = None, new_status: Optional[str] = None,
                        new_need_reparse: Optional[bool] = None) -> List[GroupDocPartRow]:
        """更新知识库分组中指定文件的状态和重解析需求。

根据给定的文件ID列表、分组名及状态列表,批量更新对应文件在知识库分组中的状态及是否需要重解析标志。

Args:
    cond_file_ids (List[str]): 需要更新的文件ID列表。
    cond_group (Optional[str]): 分组名称,若指定则只更新该分组内的文件。
    cond_status_list (Optional[List[str]]): 仅更新状态匹配此列表的文件。
    new_status (Optional[str]): 新的文件状态。
    new_need_reparse (Optional[bool]): 新的重解析需求标志。

**Returns:**

- List[Tuple]: 返回更新后文件的doc_id、group_name及状态列表。
"""
        rows = []
        conds = []
        if not cond_file_ids:
            return rows
        conds.append(KBGroupDocuments.doc_id.in_(cond_file_ids))
        if cond_group is not None:
            conds.append(KBGroupDocuments.group_name == cond_group)
        if cond_status_list:
            conds.append(KBGroupDocuments.status.in_(cond_status_list))

        vals = {}
        if new_status is not None:
            vals[KBGroupDocuments.status.name] = new_status
        if new_need_reparse is not None:
            vals[KBGroupDocuments.need_reparse.name] = new_need_reparse

        if not vals:
            return rows
        with self._db_lock, self._Session() as session:
            stmt = (
                update(KBGroupDocuments)
                .where(sqlalchemy.and_(*conds))
                .values(vals)
                .returning(KBGroupDocuments.doc_id, KBGroupDocuments.group_name, KBGroupDocuments.status)
            )
            rows = session.execute(stmt).fetchall()
            session.commit()
        return rows

    def release(self):
        """清空数据库中的所有文档、分组及相关操作日志数据。

该操作会删除 documents、document_groups、kb_group_documents 和 operation_logs 表中的所有记录。

"""
        with self._db_lock, sqlite3.connect(self._db_path, check_same_thread=self._check_same_thread) as conn:
            conn.execute('delete from documents')
            conn.execute('delete from document_groups')
            conn.execute('delete from kb_group_documents')
            conn.execute('delete from operation_logs')
            conn.commit()

    def __reduce__(self):
        return (__class__, (self._path, self._name, self._enable_path_monitoring))

add_files_to_kb_group(file_ids, group)

将多个文件添加到指定的知识库分组中。

该方法会将文件状态设置为等待处理(waiting), 若添加成功,则对应文档的计数(count)加一。

Parameters:

  • file_ids (List[str]) –

    需要添加的文件ID列表。

  • group (str) –

    知识库分组名称。

Source code in lazyllm/tools/rag/utils.py
    def add_files_to_kb_group(self, file_ids: List[str], group: str):
        """将多个文件添加到指定的知识库分组中。

该方法会将文件状态设置为等待处理(waiting),
若添加成功,则对应文档的计数(count)加一。

Args:
    file_ids (List[str]): 需要添加的文件ID列表。
    group (str): 知识库分组名称。
"""
        with self._db_lock, self._Session() as session:
            vals = []
            for doc_id in file_ids:
                vals = {
                    KBGroupDocuments.doc_id.name: doc_id,
                    KBGroupDocuments.group_name.name: group,
                    KBGroupDocuments.status.name: DocListManager.Status.waiting,
                }
                rows = session.execute(
                    insert(KBGroupDocuments).values(vals).prefix_with('OR IGNORE').returning(KBGroupDocuments.doc_id)
                ).fetchall()
                session.commit()
                if not rows:
                    continue
                doc = session.query(KBDocument).filter_by(doc_id=rows[0].doc_id).one()
                doc.count += 1
                session.commit()

add_kb_group(name)

向数据库中添加一个新的知识库分组名称,若已存在则忽略。

Parameters:

  • name (str) –

    要添加的知识库分组名称。

Source code in lazyllm/tools/rag/utils.py
    def add_kb_group(self, name):
        """向数据库中添加一个新的知识库分组名称,若已存在则忽略。

Args:
    name (str): 要添加的知识库分组名称。
"""
        with self._db_lock, sqlite3.connect(self._db_path, check_same_thread=self._check_same_thread) as conn:
            conn.execute('INSERT OR IGNORE INTO document_groups (group_name) VALUES (?)', (name,))
            conn.commit()

delete_files_from_kb_group(file_ids, group)

从指定的知识库分组中删除多个文件。

删除成功后,对应文档的计数(count)减少,但不会低于0。 若文档不存在,会记录警告日志。

Parameters:

  • file_ids (List[str]) –

    需要删除的文件ID列表。

  • group (str) –

    知识库分组名称。

Source code in lazyllm/tools/rag/utils.py
    def delete_files_from_kb_group(self, file_ids: List[str], group: str):
        """从指定的知识库分组中删除多个文件。

删除成功后,对应文档的计数(count)减少,但不会低于0。
若文档不存在,会记录警告日志。

Args:
    file_ids (List[str]): 需要删除的文件ID列表。
    group (str): 知识库分组名称。
"""
        with self._db_lock, self._Session() as session:
            for doc_id in file_ids:
                records_to_delete = (
                    session.query(KBGroupDocuments)
                    .filter(KBGroupDocuments.doc_id == doc_id, KBGroupDocuments.group_name == group)
                    .all()
                )
                for record in records_to_delete:
                    session.delete(record)
                session.commit()
                if not records_to_delete:
                    continue
                try:
                    doc = session.query(KBDocument).filter_by(doc_id=records_to_delete[0].doc_id).one()
                    doc.count = max(0, doc.count - 1)
                    session.commit()
                except NoResultFound:
                    lazyllm.LOG.warning(f"No document found for {doc_id}")

delete_unreferenced_doc()

删除数据库中标记为删除且未被任何知识库分组引用的文档记录。

该方法会查找状态为“deleting”且引用计数为0的文档,删除这些文档记录,并记录删除操作日志。

Source code in lazyllm/tools/rag/utils.py
    def delete_unreferenced_doc(self):
        """删除数据库中标记为删除且未被任何知识库分组引用的文档记录。

该方法会查找状态为“deleting”且引用计数为0的文档,删除这些文档记录,并记录删除操作日志。

"""
        with self._db_lock, self._Session() as session:
            docs_to_delete = (
                session.query(KBDocument)
                .filter(KBDocument.status == DocListManager.Status.deleting, KBDocument.count == 0)
                .all()
            )
            for doc in docs_to_delete:
                session.delete(doc)
                log = KBOperationLogs(log=f"Delete obsolete file, doc_id:{doc.doc_id}, path:{doc.path}.")
                session.add(log)
            session.commit()

fetch_docs_changed_meta(group)

获取指定知识库分组中元数据发生变化的文档列表,并将对应的 new_meta 字段清空。

Parameters:

  • group (str) –

    知识库分组名称。

Returns:

  • List[DocMetaChangedRow]: 包含文档ID及其对应新元数据的列表。
Source code in lazyllm/tools/rag/utils.py
    def fetch_docs_changed_meta(self, group: str) -> List[DocMetaChangedRow]:
        """获取指定知识库分组中元数据发生变化的文档列表,并将对应的 new_meta 字段清空。

Args:
    group (str): 知识库分组名称。

**Returns:**

- List[DocMetaChangedRow]: 包含文档ID及其对应新元数据的列表。
"""
        rows = []
        conds = [KBGroupDocuments.group_name == group, KBGroupDocuments.new_meta.isnot(None)]
        with self._db_lock, self._Session() as session:
            rows = (
                session.query(KBDocument.doc_id, KBGroupDocuments.new_meta)
                .join(KBGroupDocuments, KBDocument.doc_id == KBGroupDocuments.doc_id)
                .filter(*conds).all()
            )
            stmt = update(KBGroupDocuments).where(sqlalchemy.and_(*conds)).values(new_meta=None)
            session.execute(stmt)
            session.commit()
        return rows

get_docs(doc_ids)

根据给定的文档ID列表,从数据库中获取对应的文档对象列表。

Parameters:

  • doc_ids (List[str]) –

    需要查询的文档ID列表。

Returns:

  • List[KBDocument]: 匹配的文档对象列表。如果没有匹配项,返回空列表。
Source code in lazyllm/tools/rag/utils.py
    def get_docs(self, doc_ids: List[str]) -> List[KBDocument]:
        """根据给定的文档ID列表,从数据库中获取对应的文档对象列表。

Args:
    doc_ids (List[str]): 需要查询的文档ID列表。

**Returns:**

- List[KBDocument]: 匹配的文档对象列表。如果没有匹配项,返回空列表。
"""
        with self._db_lock, self._Session() as session:
            docs = session.query(KBDocument).filter(KBDocument.doc_id.in_(doc_ids)).all()
            return docs
        return []

get_docs_need_reparse(group)

获取指定知识库分组中需要重新解析的文档列表。

仅返回状态为“success”或“failed”的文档,且其对应的知识库分组记录标记为需要重新解析。

Parameters:

  • group (str) –

    知识库分组名称。

Returns:

  • List[KBDocument]: 需要重新解析的文档列表。
Source code in lazyllm/tools/rag/utils.py
    def get_docs_need_reparse(self, group: str) -> List[KBDocument]:
        """获取指定知识库分组中需要重新解析的文档列表。

仅返回状态为“success”或“failed”的文档,且其对应的知识库分组记录标记为需要重新解析。

Args:
    group (str): 知识库分组名称。

**Returns:**

- List[KBDocument]: 需要重新解析的文档列表。
"""
        with self._db_lock, self._Session() as session:
            filter_status_list = [DocListManager.Status.success, DocListManager.Status.failed]
            documents = (
                session.query(KBDocument).join(KBGroupDocuments, KBDocument.doc_id == KBGroupDocuments.doc_id)
                .filter(KBGroupDocuments.need_reparse.is_(True),
                        KBGroupDocuments.group_name == group,
                        KBGroupDocuments.status.in_(filter_status_list)).all())
            return documents
        return []

get_existing_paths_by_pattern(pattern)

根据路径匹配模式获取已存在的文档路径列表。

Parameters:

  • pattern (str) –

    路径匹配模式,支持SQL的LIKE通配符。

Returns:

  • List[str]: 匹配到的已存在文档路径列表。
Source code in lazyllm/tools/rag/utils.py
    def get_existing_paths_by_pattern(self, pattern: str) -> List[str]:
        """根据路径匹配模式获取已存在的文档路径列表。

Args:
    pattern (str): 路径匹配模式,支持SQL的LIKE通配符。

**Returns:**

- List[str]: 匹配到的已存在文档路径列表。
"""
        exist_paths = []
        with self._db_lock, self._Session() as session:
            docs = session.query(KBDocument).filter(KBDocument.path.like(pattern)).all()
            exist_paths = [doc.path for doc in docs]
        return exist_paths

get_file_status(fileid)

获取指定文件的状态。

Parameters:

  • fileid (str) –

    文件的唯一标识符。

Returns:

  • Optional[Tuple]: 返回包含状态的元组,若文件不存在则返回 None。
Source code in lazyllm/tools/rag/utils.py
    def get_file_status(self, fileid: str):
        """获取指定文件的状态。

Args:
    fileid (str): 文件的唯一标识符。

**Returns:**

- Optional[Tuple]: 返回包含状态的元组,若文件不存在则返回 None。
"""
        with self._db_lock, sqlite3.connect(self._db_path, check_same_thread=self._check_same_thread) as conn:
            cursor = conn.execute("SELECT status FROM documents WHERE doc_id = ?", (fileid,))
        return cursor.fetchone()

list_all_kb_group()

列出数据库中所有的知识库分组名称。

Returns:

  • List[str]: 知识库分组名称列表。
Source code in lazyllm/tools/rag/utils.py
    def list_all_kb_group(self):
        """列出数据库中所有的知识库分组名称。

**Returns:**

- List[str]: 知识库分组名称列表。
"""
        with self._db_lock, sqlite3.connect(self._db_path, check_same_thread=self._check_same_thread) as conn:
            cursor = conn.execute("SELECT group_name FROM document_groups")
            return [row[0] for row in cursor]

list_files(limit=None, details=False, status=DocListManager.Status.all, exclude_status=None)

列出文档数据库中符合状态条件的文件,并根据参数选择返回完整记录或仅返回文件路径。

Parameters:

  • limit (Optional[int], default: None ) –

    要返回的记录数上限,若为 None 则返回所有符合条件的记录。

  • details (bool, default: False ) –

    是否返回完整的数据库行信息,若为 False 则仅返回文档路径(ID)。

  • status (Union[str, List[str]], default: all ) –

    要包含在结果中的状态值,默认为包含所有状态。

  • exclude_status (Optional[Union[str, List[str]]], default: None ) –

    要从结果中排除的状态值。

Returns:

  • list: 文件记录列表或文档路径列表,具体取决于 details 参数。
Source code in lazyllm/tools/rag/utils.py
    def list_files(self, limit: Optional[int] = None, details: bool = False,
                   status: Union[str, List[str]] = DocListManager.Status.all,
                   exclude_status: Optional[Union[str, List[str]]] = None):
        """列出文档数据库中符合状态条件的文件,并根据参数选择返回完整记录或仅返回文件路径。

Args:
    limit (Optional[int]): 要返回的记录数上限,若为 None 则返回所有符合条件的记录。
    details (bool): 是否返回完整的数据库行信息,若为 False 则仅返回文档路径(ID)。
    status (Union[str, List[str]]): 要包含在结果中的状态值,默认为包含所有状态。
    exclude_status (Optional[Union[str, List[str]]]): 要从结果中排除的状态值。

**Returns:**

- list: 文件记录列表或文档路径列表,具体取决于 `details` 参数。
"""
        query = "SELECT * FROM documents"
        params = []
        status_cond, status_params = self.get_status_cond_and_params(status, exclude_status, prefix=None)
        if status_cond:
            query += f' WHERE {status_cond}'
            params.extend(status_params)
        if limit:
            query += " LIMIT ?"
            params.append(limit)
        with self._db_lock, sqlite3.connect(self._db_path, check_same_thread=self._check_same_thread) as conn:
            cursor = conn.execute(query, params)
            return cursor.fetchall() if details else [row[0] for row in cursor]

list_kb_group_files(group=None, limit=None, details=False, status=DocListManager.Status.all, exclude_status=None, upload_status=DocListManager.Status.all, exclude_upload_status=None, need_reparse=None)

列出指定知识库分组中的文件信息,可根据多种条件进行过滤。

Parameters:

  • group (str, default: None ) –

    知识库分组名称,若为 None 则不按分组过滤。

  • limit (int, default: None ) –

    限制返回的文件数量。

  • details (bool, default: False ) –

    是否返回详细的文件信息。

  • status (str or List[str], default: all ) –

    过滤知识库分组中文件的状态。

  • exclude_status (str or List[str], default: None ) –

    排除指定状态的文件。

  • upload_status (str or List[str], default: all ) –

    过滤文件上传状态。

  • exclude_upload_status (str or List[str], default: None ) –

    排除指定的上传状态。

  • need_reparse (bool, default: None ) –

    是否只返回需要重新解析的文件。

Returns:

  • list:
    • 如果 details 为 False,返回列表,每个元素为 (doc_id, path) 元组。
    • 如果 details 为 True,返回包含文件详细信息的元组列表,包括文档ID、路径、状态、元数据, 知识库分组名、分组内状态及日志。
Source code in lazyllm/tools/rag/utils.py
    def list_kb_group_files(self, group: str = None, limit: Optional[int] = None, details: bool = False,
                            status: Union[str, List[str]] = DocListManager.Status.all,
                            exclude_status: Optional[Union[str, List[str]]] = None,
                            upload_status: Union[str, List[str]] = DocListManager.Status.all,
                            exclude_upload_status: Optional[Union[str, List[str]]] = None,
                            need_reparse: Optional[bool] = None):
        """列出指定知识库分组中的文件信息,可根据多种条件进行过滤。

Args:
    group (str, optional): 知识库分组名称,若为 None 则不按分组过滤。
    limit (int, optional): 限制返回的文件数量。
    details (bool): 是否返回详细的文件信息。
    status (str or List[str], optional): 过滤知识库分组中文件的状态。
    exclude_status (str or List[str], optional): 排除指定状态的文件。
    upload_status (str or List[str], optional): 过滤文件上传状态。
    exclude_upload_status (str or List[str], optional): 排除指定的上传状态。
    need_reparse (bool, optional): 是否只返回需要重新解析的文件。

**Returns:**

- list: 
    - 如果 details 为 False,返回列表,每个元素为 (doc_id, path) 元组。
    - 如果 details 为 True,返回包含文件详细信息的元组列表,包括文档ID、路径、状态、元数据,
      知识库分组名、分组内状态及日志。
"""
        query = """
            SELECT documents.doc_id, documents.path, documents.status, documents.meta,
                   kb_group_documents.group_name, kb_group_documents.status, kb_group_documents.log
            FROM kb_group_documents
            JOIN documents ON kb_group_documents.doc_id = documents.doc_id
        """
        conds, params = [], []
        if group:
            conds.append('kb_group_documents.group_name = ?')
            params.append(group)

        if need_reparse is not None:
            conds.append('kb_group_documents.need_reparse = ?')
            params.append(int(need_reparse))

        status_cond, status_params = self.get_status_cond_and_params(status, exclude_status, prefix='kb_group_documents')
        if status_cond:
            conds.append(status_cond)
            params.extend(status_params)

        status_cond, status_params = self.get_status_cond_and_params(
            upload_status, exclude_upload_status, prefix='documents')
        if status_cond:
            conds.append(status_cond)
            params.extend(status_params)

        if conds: query += ' WHERE ' + ' AND '.join(conds)

        if limit:
            query += ' LIMIT ?'
            params.append(limit)

        with self._db_lock, sqlite3.connect(self._db_path, check_same_thread=self._check_same_thread) as conn:
            cursor = conn.execute(query, params)
            rows = cursor.fetchall()

        if not details: return [row[:2] for row in rows]
        return rows

release()

清空数据库中的所有文档、分组及相关操作日志数据。

该操作会删除 documents、document_groups、kb_group_documents 和 operation_logs 表中的所有记录。

Source code in lazyllm/tools/rag/utils.py
    def release(self):
        """清空数据库中的所有文档、分组及相关操作日志数据。

该操作会删除 documents、document_groups、kb_group_documents 和 operation_logs 表中的所有记录。

"""
        with self._db_lock, sqlite3.connect(self._db_path, check_same_thread=self._check_same_thread) as conn:
            conn.execute('delete from documents')
            conn.execute('delete from document_groups')
            conn.execute('delete from kb_group_documents')
            conn.execute('delete from operation_logs')
            conn.commit()

set_docs_new_meta(doc_meta)

批量更新文档的元数据(meta),同时更新对应知识库分组中文档的 new_meta 字段(非等待状态的文档)。

Parameters:

  • doc_meta (Dict[str, dict]) –

    字典,键为文档ID,值为对应的新元数据字典。

Source code in lazyllm/tools/rag/utils.py
    def set_docs_new_meta(self, doc_meta: Dict[str, dict]):
        """批量更新文档的元数据(meta),同时更新对应知识库分组中文档的 new_meta 字段(非等待状态的文档)。

Args:
    doc_meta (Dict[str, dict]): 字典,键为文档ID,值为对应的新元数据字典。
"""
        data_to_update = [{"_doc_id": k, "_meta": json.dumps(v)} for k, v in doc_meta.items()]
        with self._db_lock, self._Session() as session:
            # Use sqlalchemy core bulk update
            stmt = KBDocument.__table__.update().where(
                KBDocument.doc_id == bindparam("_doc_id")).values(meta=bindparam("_meta"))
            session.execute(stmt, data_to_update)
            session.commit()

            stmt = KBGroupDocuments.__table__.update().where(
                KBGroupDocuments.doc_id == bindparam("_doc_id"),
                KBGroupDocuments.status != DocListManager.Status.waiting).values(new_meta=bindparam("_meta"))
            session.execute(stmt, data_to_update)
            session.commit()

table_inited()

检查数据库中是否已存在名为 "documents" 的表。

该方法通过查询 sqlite_master 元信息表,判断数据表是否已初始化。

Returns:

  • bool: 如果 "documents" 表存在,返回 True;否则返回 False。
Source code in lazyllm/tools/rag/utils.py
    def table_inited(self):
        """检查数据库中是否已存在名为 "documents" 的表。

该方法通过查询 sqlite_master 元信息表,判断数据表是否已初始化。

**Returns:**

- bool: 如果 "documents" 表存在,返回 True;否则返回 False。
"""
        with self._db_lock, sqlite3.connect(self._db_path, check_same_thread=self._check_same_thread) as conn:
            cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='documents'")
            return cursor.fetchone() is not None

update_file_message(fileid, **kw)

更新指定文件的字段信息。

Parameters:

  • fileid (str) –

    文件的唯一标识符(doc_id)。

  • **kw

    需要更新的字段及其对应的值,键值对形式传入。

Source code in lazyllm/tools/rag/utils.py
    def update_file_message(self, fileid: str, **kw):
        """更新指定文件的字段信息。

Args:
    fileid (str): 文件的唯一标识符(doc_id)。
    **kw: 需要更新的字段及其对应的值,键值对形式传入。
"""
        set_clause = ", ".join([f"{k} = ?" for k in kw.keys()])
        params = list(kw.values()) + [fileid]
        with self._db_lock, sqlite3.connect(self._db_path, check_same_thread=self._check_same_thread) as conn:
            conn.execute(f"UPDATE documents SET {set_clause} WHERE doc_id = ?", params)
            conn.commit()

update_file_status(file_ids, status, cond_status_list=None)

更新多个文件的状态,支持根据当前状态进行条件过滤。

Parameters:

  • file_ids (List[str]) –

    需要更新状态的文件ID列表。

  • status (str) –

    要设置的新状态。

  • cond_status_list (Union[None, List[str]], default: None ) –

    仅更新当前状态在此列表中的文件,默认为 None,表示不筛选。

Returns:

  • List[DocPartRow]: 返回更新后的文件ID和路径列表。
Source code in lazyllm/tools/rag/utils.py
    def update_file_status(self, file_ids: List[str], status: str,
                           cond_status_list: Union[None, List[str]] = None) -> List[DocPartRow]:
        """更新多个文件的状态,支持根据当前状态进行条件过滤。

Args:
    file_ids (List[str]): 需要更新状态的文件ID列表。
    status (str): 要设置的新状态。
    cond_status_list (Union[None, List[str]], optional): 仅更新当前状态在此列表中的文件,默认为 None,表示不筛选。

**Returns:**

- List[DocPartRow]: 返回更新后的文件ID和路径列表。
"""
        rows = []
        if cond_status_list is None:
            sql_cond = KBDocument.doc_id.in_(file_ids)
        else:
            sql_cond = sqlalchemy.and_(KBDocument.status.in_(cond_status_list), KBDocument.doc_id.in_(file_ids))
        with self._db_lock, self._Session() as session:
            stmt = (
                update(KBDocument)
                .where(sql_cond)
                .values(status=status)
                .returning(KBDocument.doc_id, KBDocument.path)
            )
            rows = session.execute(stmt).fetchall()
            session.commit()
        return rows

update_kb_group(cond_file_ids, cond_group=None, cond_status_list=None, new_status=None, new_need_reparse=None)

更新知识库分组中指定文件的状态和重解析需求。

根据给定的文件ID列表、分组名及状态列表,批量更新对应文件在知识库分组中的状态及是否需要重解析标志。

Parameters:

  • cond_file_ids (List[str]) –

    需要更新的文件ID列表。

  • cond_group (Optional[str], default: None ) –

    分组名称,若指定则只更新该分组内的文件。

  • cond_status_list (Optional[List[str]], default: None ) –

    仅更新状态匹配此列表的文件。

  • new_status (Optional[str], default: None ) –

    新的文件状态。

  • new_need_reparse (Optional[bool], default: None ) –

    新的重解析需求标志。

Returns:

  • List[Tuple]: 返回更新后文件的doc_id、group_name及状态列表。
Source code in lazyllm/tools/rag/utils.py
    def update_kb_group(self, cond_file_ids: List[str], cond_group: Optional[str] = None,
                        cond_status_list: Optional[List[str]] = None, new_status: Optional[str] = None,
                        new_need_reparse: Optional[bool] = None) -> List[GroupDocPartRow]:
        """更新知识库分组中指定文件的状态和重解析需求。

根据给定的文件ID列表、分组名及状态列表,批量更新对应文件在知识库分组中的状态及是否需要重解析标志。

Args:
    cond_file_ids (List[str]): 需要更新的文件ID列表。
    cond_group (Optional[str]): 分组名称,若指定则只更新该分组内的文件。
    cond_status_list (Optional[List[str]]): 仅更新状态匹配此列表的文件。
    new_status (Optional[str]): 新的文件状态。
    new_need_reparse (Optional[bool]): 新的重解析需求标志。

**Returns:**

- List[Tuple]: 返回更新后文件的doc_id、group_name及状态列表。
"""
        rows = []
        conds = []
        if not cond_file_ids:
            return rows
        conds.append(KBGroupDocuments.doc_id.in_(cond_file_ids))
        if cond_group is not None:
            conds.append(KBGroupDocuments.group_name == cond_group)
        if cond_status_list:
            conds.append(KBGroupDocuments.status.in_(cond_status_list))

        vals = {}
        if new_status is not None:
            vals[KBGroupDocuments.status.name] = new_status
        if new_need_reparse is not None:
            vals[KBGroupDocuments.need_reparse.name] = new_need_reparse

        if not vals:
            return rows
        with self._db_lock, self._Session() as session:
            stmt = (
                update(KBGroupDocuments)
                .where(sqlalchemy.and_(*conds))
                .values(vals)
                .returning(KBGroupDocuments.doc_id, KBGroupDocuments.group_name, KBGroupDocuments.status)
            )
            rows = session.execute(stmt).fetchall()
            session.commit()
        return rows

update_need_reparsing(doc_id, need_reparse, group_name=None)

更新指定文档的重解析标志位。

该方法用于设置某个文档是否需要重新解析。可以选择性地指定知识库分组进行精确匹配。

Parameters:

  • doc_id (str) –

    文档的唯一标识符。

  • need_reparse (bool) –

    是否需要重新解析文档。

  • group_name (Optional[str], default: None ) –

    可选,所属的知识库分组名称。如果提供,将仅更新指定分组中的文档。

Source code in lazyllm/tools/rag/utils.py
    def update_need_reparsing(self, doc_id: str, need_reparse: bool, group_name: Optional[str] = None):
        """更新指定文档的重解析标志位。

该方法用于设置某个文档是否需要重新解析。可以选择性地指定知识库分组进行精确匹配。

Args:
    doc_id (str): 文档的唯一标识符。
    need_reparse (bool): 是否需要重新解析文档。
    group_name (Optional[str]): 可选,所属的知识库分组名称。如果提供,将仅更新指定分组中的文档。
"""
        with self._db_lock, self._Session() as session:
            stmt = update(KBGroupDocuments).where(KBGroupDocuments.doc_id == doc_id)
            if group_name is not None: stmt = stmt.where(KBGroupDocuments.group_name == group_name)
            session.execute(stmt.values(need_reparse=need_reparse))
            session.commit()

validate_paths(paths)

验证输入路径所对应的文档是否可以安全添加到数据库。

该方法会检查每个路径是否对应已有文档,若已存在,需判断其状态是否允许重解析。 若文档正在解析或等待解析,或上次重解析未完成,则视为不可用。

Parameters:

  • paths (List[str]) –

    文件路径列表。

Returns:

  • Tuple[bool, str, List[bool]]:
    • bool: 是否所有路径都验证通过。
    • str: 成功或失败的描述信息。
    • List[bool]: 与输入路径一一对应的布尔列表,表示该路径是否为新文档(True 为新文档,False 为已存在)。 若验证失败,返回值为 None。
Source code in lazyllm/tools/rag/utils.py
    def validate_paths(self, paths: List[str]) -> Tuple[bool, str, List[bool]]:
        """验证输入路径所对应的文档是否可以安全添加到数据库。

该方法会检查每个路径是否对应已有文档,若已存在,需判断其状态是否允许重解析。
若文档正在解析或等待解析,或上次重解析未完成,则视为不可用。

Args:
    paths (List[str]): 文件路径列表。

**Returns:**

- Tuple[bool, str, List[bool]]: 
    - bool: 是否所有路径都验证通过。
    - str: 成功或失败的描述信息。
    - List[bool]: 与输入路径一一对应的布尔列表,表示该路径是否为新文档(True 为新文档,False 为已存在)。
        若验证失败,返回值为 None。
"""
        # check and return: success, msg, path_is_new for each path
        unsafe_staus_set = set([DocListManager.Status.working, DocListManager.Status.waiting])
        paths_is_new = [True] * len(paths)
        doc_ids = [gen_docid(path) for path in paths]
        doc_id_to_path = {doc_id: path for doc_id, path in zip(doc_ids, paths)}
        found_doc_ids = []
        found_doc_group_rows = []
        with self._db_lock, self._Session() as session:
            rows = session.execute(
                select(KBDocument.doc_id).where(KBDocument.doc_id.in_(doc_ids))
            ).fetchall()
            if len(rows) == 0:
                return True, "Success", paths_is_new
            found_doc_ids = [row.doc_id for row in rows]
            found_doc_group_rows = session.execute(
                select(KBGroupDocuments.doc_id, KBGroupDocuments.need_reparse, KBGroupDocuments.status)
                .where(KBGroupDocuments.doc_id.in_(found_doc_ids))).fetchall()

        for doc_group_record in found_doc_group_rows:
            if doc_group_record.need_reparse:
                msg = f"Failed: {doc_id_to_path[doc_group_record.doc_id]} lasttime reparsing has not been finished"
                return False, msg, None
            if doc_group_record.status in unsafe_staus_set:
                return False, f"Failed: {doc_id_to_path[doc_group_record.doc_id]} is being parsed by kbgroup", None
        found_doc_ids = set(found_doc_ids)
        for i in range(len(paths)):
            cur_doc_id = doc_ids[i]
            if cur_doc_id in found_doc_ids:
                paths_is_new[i] = False
        return True, "Success", paths_is_new

lazyllm.tools.rag.data_loaders.DirectoryReader

用于从文件目录加载和处理文档的目录读取器类。

此类提供从指定目录读取文档并将其转换为文档节点的功能。它支持本地和全局文件读取器,并且可以处理不同类型的文档,包括图像。

Parameters:

  • input_files (Optional[List[str]]) –

    要读取的文件路径列表。如果为None,文件将在调用load_data方法时加载。

  • local_readers (Optional[Dict], default: None ) –

    特定于此实例的本地文件读取器字典。键是文件模式,值是读取器函数。

  • global_readers (Optional[Dict], default: None ) –

    在所有实例间共享的全局文件读取器字典。键是文件模式,值是读取器函数。

Examples:

>>> from lazyllm.tools.rag.data_loaders import DirectoryReader
>>> from lazyllm.tools.rag.readers import DocxReader, PDFReader
>>> local_readers = {
...     "**/*.docx": DocxReader,
...     "**/*.pdf": PDFReader
>>> }
>>> reader = DirectoryReader(
...     input_files=["path/to/documents"],
...     local_readers=local_readers,
...     global_readers={}
>>> )
>>> documents = reader.load_data()
>>> print(f"加载了 {len(documents)} 个文档")
Source code in lazyllm/tools/rag/data_loaders.py
class DirectoryReader:
    """用于从文件目录加载和处理文档的目录读取器类。

此类提供从指定目录读取文档并将其转换为文档节点的功能。它支持本地和全局文件读取器,并且可以处理不同类型的文档,包括图像。

Args:
    input_files (Optional[List[str]]): 要读取的文件路径列表。如果为None,文件将在调用load_data方法时加载。
    local_readers (Optional[Dict]): 特定于此实例的本地文件读取器字典。键是文件模式,值是读取器函数。
    global_readers (Optional[Dict]): 在所有实例间共享的全局文件读取器字典。键是文件模式,值是读取器函数。


Examples:
    >>> from lazyllm.tools.rag.data_loaders import DirectoryReader
    >>> from lazyllm.tools.rag.readers import DocxReader, PDFReader
    >>> local_readers = {
    ...     "**/*.docx": DocxReader,
    ...     "**/*.pdf": PDFReader
    >>> }
    >>> reader = DirectoryReader(
    ...     input_files=["path/to/documents"],
    ...     local_readers=local_readers,
    ...     global_readers={}
    >>> )
    >>> documents = reader.load_data()
    >>> print(f"加载了 {len(documents)} 个文档")
    """
    def __init__(self, input_files: Optional[List[str]], local_readers: Optional[Dict] = None,
                 global_readers: Optional[Dict] = None) -> None:
        self._input_files = input_files
        self._local_readers = local_readers
        self._global_readers = global_readers

    def load_data(self, input_files: Optional[List[str]] = None, metadatas: Optional[Dict] = None,
                  *, split_image_nodes: bool = False) -> List[DocNode]:
        """从指定的输入文件加载和处理文档。

此方法使用配置的文件读取器(本地和全局)从输入文件读取文档,将它们处理成文档节点,并可选地将图像节点与文本节点分离。

Args:
    input_files (Optional[List[str]]): 要读取的文件路径列表。如果为None,使用初始化时指定的文件。
    metadatas (Optional[Dict]): 与加载文档关联的额外元数据。
    split_image_nodes (bool): 是否将图像节点与文本节点分离。如果为True,返回(text_nodes, image_nodes)的元组。如果为False,一起返回所有节点。

**Returns:**
- Union[List[DocNode], Tuple[List[DocNode], List[ImageDocNode]]]: 如果split_image_nodes为False,返回所有文档节点的列表。如果为True,返回包含文本节点和图像节点的元组。
"""
        input_files = input_files or self._input_files
        file_readers = self._local_readers.copy()
        for key, func in self._global_readers.items():
            if key not in file_readers: file_readers[key] = func
        LOG.info(f"DirectoryReader loads data, input files: {input_files}")
        reader = SimpleDirectoryReader(input_files=input_files, file_extractor=file_readers, metadatas=metadatas)
        nodes: List[DocNode] = []
        image_nodes: List[ImageDocNode] = []
        for doc in reader():
            doc._group = LAZY_IMAGE_GROUP if isinstance(doc, ImageDocNode) else LAZY_ROOT_NAME
            if not split_image_nodes or not isinstance(doc, ImageDocNode):
                nodes.append(doc)
            else:
                image_nodes.append(doc)
        if not nodes and not image_nodes:
            LOG.warning(
                f"No nodes load from path {input_files}, please check your data path."
            )
        LOG.info("DirectoryReader loads data done!")
        return (nodes, image_nodes) if split_image_nodes else nodes

load_data(input_files=None, metadatas=None, *, split_image_nodes=False)

从指定的输入文件加载和处理文档。

此方法使用配置的文件读取器(本地和全局)从输入文件读取文档,将它们处理成文档节点,并可选地将图像节点与文本节点分离。

Parameters:

  • input_files (Optional[List[str]], default: None ) –

    要读取的文件路径列表。如果为None,使用初始化时指定的文件。

  • metadatas (Optional[Dict], default: None ) –

    与加载文档关联的额外元数据。

  • split_image_nodes (bool, default: False ) –

    是否将图像节点与文本节点分离。如果为True,返回(text_nodes, image_nodes)的元组。如果为False,一起返回所有节点。

Returns: - Union[List[DocNode], Tuple[List[DocNode], List[ImageDocNode]]]: 如果split_image_nodes为False,返回所有文档节点的列表。如果为True,返回包含文本节点和图像节点的元组。

Source code in lazyllm/tools/rag/data_loaders.py
    def load_data(self, input_files: Optional[List[str]] = None, metadatas: Optional[Dict] = None,
                  *, split_image_nodes: bool = False) -> List[DocNode]:
        """从指定的输入文件加载和处理文档。

此方法使用配置的文件读取器(本地和全局)从输入文件读取文档,将它们处理成文档节点,并可选地将图像节点与文本节点分离。

Args:
    input_files (Optional[List[str]]): 要读取的文件路径列表。如果为None,使用初始化时指定的文件。
    metadatas (Optional[Dict]): 与加载文档关联的额外元数据。
    split_image_nodes (bool): 是否将图像节点与文本节点分离。如果为True,返回(text_nodes, image_nodes)的元组。如果为False,一起返回所有节点。

**Returns:**
- Union[List[DocNode], Tuple[List[DocNode], List[ImageDocNode]]]: 如果split_image_nodes为False,返回所有文档节点的列表。如果为True,返回包含文本节点和图像节点的元组。
"""
        input_files = input_files or self._input_files
        file_readers = self._local_readers.copy()
        for key, func in self._global_readers.items():
            if key not in file_readers: file_readers[key] = func
        LOG.info(f"DirectoryReader loads data, input files: {input_files}")
        reader = SimpleDirectoryReader(input_files=input_files, file_extractor=file_readers, metadatas=metadatas)
        nodes: List[DocNode] = []
        image_nodes: List[ImageDocNode] = []
        for doc in reader():
            doc._group = LAZY_IMAGE_GROUP if isinstance(doc, ImageDocNode) else LAZY_ROOT_NAME
            if not split_image_nodes or not isinstance(doc, ImageDocNode):
                nodes.append(doc)
            else:
                image_nodes.append(doc)
        if not nodes and not image_nodes:
            LOG.warning(
                f"No nodes load from path {input_files}, please check your data path."
            )
        LOG.info("DirectoryReader loads data done!")
        return (nodes, image_nodes) if split_image_nodes else nodes

lazyllm.tools.SentenceSplitter

Bases: NodeTransform

将句子拆分成指定大小的块。可以指定相邻块之间重合部分的大小。

Parameters:

  • chunk_size (int, default: 1024 ) –

    拆分之后的块大小

  • chunk_overlap (int, default: 200 ) –

    相邻两个块之间重合的内容长度

  • num_workers(int)

    控制并行处理的线程/进程数量

Examples:

>>> import lazyllm
>>> from lazyllm.tools import Document, SentenceSplitter
>>> m = lazyllm.OnlineEmbeddingModule(source="glm")
>>> documents = Document(dataset_path='your_doc_path', embed=m, manager=False)
>>> documents.create_node_group(name="sentences", transform=SentenceSplitter, chunk_size=1024, chunk_overlap=100)
Source code in lazyllm/tools/rag/transform.py
class SentenceSplitter(NodeTransform):
    """
将句子拆分成指定大小的块。可以指定相邻块之间重合部分的大小。

Args:
    chunk_size (int): 拆分之后的块大小
    chunk_overlap (int): 相邻两个块之间重合的内容长度
    num_workers(int):控制并行处理的线程/进程数量


Examples:

    >>> import lazyllm
    >>> from lazyllm.tools import Document, SentenceSplitter
    >>> m = lazyllm.OnlineEmbeddingModule(source="glm")
    >>> documents = Document(dataset_path='your_doc_path', embed=m, manager=False)
    >>> documents.create_node_group(name="sentences", transform=SentenceSplitter, chunk_size=1024, chunk_overlap=100)
    """
    def __init__(self, chunk_size: int = 1024, chunk_overlap: int = 200, num_workers: int = 0):
        super(__class__, self).__init__(num_workers=num_workers)
        if chunk_overlap > chunk_size:
            raise ValueError(
                f'Got a larger chunk overlap ({chunk_overlap}) than chunk size '
                f'({chunk_size}), should be smaller.'
            )

        assert (
            chunk_size > 0 and chunk_overlap >= 0
        ), 'chunk size should > 0 and chunk_overlap should >= 0'

        try:
            if 'TIKTOKEN_CACHE_DIR' not in os.environ and 'DATA_GYM_CACHE_DIR' not in os.environ:
                path = os.path.join(config['model_path'], 'tiktoken')
                os.makedirs(path, exist_ok=True)
                os.environ['TIKTOKEN_CACHE_DIR'] = path
            self._tiktoken_tokenizer = tiktoken.encoding_for_model('gpt-3.5-turbo')
            os.environ.pop('TIKTOKEN_CACHE_DIR')
        except requests.exceptions.ConnectionError:
            LOG.error(
                'Unable to download the vocabulary file for tiktoken `gpt-3.5-turbo`. '
                'Please check your internet connection. '
                'Alternatively, you can manually download the file '
                'and set the `TIKTOKEN_CACHE_DIR` environment variable.'
            )
            raise
        except Exception as e:
            LOG.error(f'Unable to build tiktoken tokenizer with error `{e}`')
            raise
        self._punkt_st_tokenizer = nltk.tokenize.PunktSentenceTokenizer()

        self._sentence_split_fns = [
            partial(split_text_keep_separator, separator='\n\n\n'),  # paragraph
            self._punkt_st_tokenizer.tokenize,
        ]

        self._sub_sentence_split_fns = [
            lambda t: re.findall(r'[^,.;。?!]+[,.;。?!]?', t),
            partial(split_text_keep_separator, separator=' '),
            list,  # split by character
        ]

        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    def transform(self, node: DocNode, **kwargs) -> List[str]:
        return self.split_text(
            node.get_text(),
            metadata_size=self._get_metadata_size(node),
        )

    def _get_metadata_size(self, node: DocNode) -> int:
        # Return the bigger size to ensure chunk_size < limit
        return max(
            self._token_size(node.get_metadata_str(mode=MetadataMode.EMBED)),
            self._token_size(node.get_metadata_str(mode=MetadataMode.LLM)),
        )

    def split_text(self, text: str, metadata_size: int) -> List[str]:
        if text == '':
            return ['']
        effective_chunk_size = self.chunk_size - metadata_size
        if effective_chunk_size <= 0:
            raise ValueError(
                f'Metadata length ({metadata_size}) is longer than chunk size '
                f'({self.chunk_size}). Consider increasing the chunk size or '
                'decreasing the size of your metadata to avoid this.'
            )
        elif effective_chunk_size < 50:
            LOG.warning(
                f'Metadata length ({metadata_size}) is close to chunk size '
                f'({self.chunk_size}). Resulting chunks are less than 50 tokens. '
                'Consider increasing the chunk size or decreasing the size of '
                'your metadata to avoid this.'
            )

        splits = self._split(text, effective_chunk_size)
        chunks = self._merge(splits, effective_chunk_size)
        return chunks

    def _split(self, text: str, chunk_size: int) -> List[_Split]:
        """Break text into splits that are smaller than chunk size.

        The order of splitting is:
        1. split by paragraph separator
        2. split by chunking tokenizer
        3. split by second chunking regex
        4. split by default separator (' ')
        5. split by character
        """
        token_size = self._token_size(text)
        if token_size <= chunk_size:
            return [_Split(text, is_sentence=True, token_size=token_size)]

        text_splits_by_fns, is_sentence = self._get_splits_by_fns(text)

        text_splits = []
        for text in text_splits_by_fns:
            token_size = self._token_size(text)
            if token_size <= chunk_size:
                text_splits.append(
                    _Split(
                        text,
                        is_sentence=is_sentence,
                        token_size=token_size,
                    )
                )
            else:
                recursive_text_splits = self._split(text, chunk_size=chunk_size)
                text_splits.extend(recursive_text_splits)
        return text_splits

    def _merge(self, splits: List[_Split], chunk_size: int) -> List[str]:
        chunks: List[str] = []
        cur_chunk: List[Tuple[str, int]] = []  # list of (text, length)
        cur_chunk_len = 0
        is_chunk_new = True

        def close_chunk() -> None:
            nonlocal cur_chunk, cur_chunk_len, is_chunk_new

            chunks.append(''.join([text for text, _ in cur_chunk]))
            last_chunk = cur_chunk
            cur_chunk = []
            cur_chunk_len = 0
            is_chunk_new = True

            # Add overlap to the next chunk using the last one first
            overlap_len = 0
            for text, length in reversed(last_chunk):
                if overlap_len + length > self.chunk_overlap:
                    break
                cur_chunk.append((text, length))
                overlap_len += length
                cur_chunk_len += length
            cur_chunk.reverse()

        i = 0
        while i < len(splits):
            cur_split = splits[i]
            if cur_split.token_size > chunk_size:
                raise ValueError('Single token exceeded chunk size')
            if cur_chunk_len + cur_split.token_size > chunk_size and not is_chunk_new:
                # if adding split to current chunk exceeds chunk size
                close_chunk()
            else:
                if (
                    cur_split.is_sentence
                    or cur_chunk_len + cur_split.token_size <= chunk_size
                    or is_chunk_new  # new chunk, always add at least one split
                ):
                    # add split to chunk
                    cur_chunk_len += cur_split.token_size
                    cur_chunk.append((cur_split.text, cur_split.token_size))
                    i += 1
                    is_chunk_new = False
                else:
                    close_chunk()

        # handle the last chunk
        if not is_chunk_new:
            chunks.append(''.join([text for text, _ in cur_chunk]))

        # Remove whitespace only chunks and remove leading and trailing whitespace.
        return [stripped_chunk for chunk in chunks if (stripped_chunk := chunk.strip())]

    def _token_size(self, text: str) -> int:
        return len(self._tiktoken_tokenizer.encode(text, allowed_special='all'))

    def _get_splits_by_fns(self, text: str) -> Tuple[List[str], bool]:
        for split_fn in self._sentence_split_fns:
            splits = split_fn(text)
            if len(splits) > 1:
                return splits, True

        for split_fn in self._sub_sentence_split_fns:
            splits = split_fn(text)
            if len(splits) > 1:
                break

        return splits, False

lazyllm.tools.LLMParser

Bases: NodeTransform

一个文本摘要和关键词提取器,负责分析用户输入的文本,并根据请求任务提供简洁的摘要或提取相关关键词。

Parameters:

  • llm (TrainableModule) –

    可训练的模块

  • language (str) –

    语言种类,目前只支持中文(zh)和英文(en)

  • task_type (str) –

    目前支持两种任务:摘要(summary)和关键词抽取(keywords)。

  • num_workers(int)

    控制并行处理的线程/进程数量。

Examples:

>>> from lazyllm import TrainableModule
>>> from lazyllm.tools.rag import LLMParser
>>> llm = TrainableModule("internlm2-chat-7b")
>>> summary_parser = LLMParser(llm, language="en", task_type="summary")
Source code in lazyllm/tools/rag/transform.py
class LLMParser(NodeTransform):
    """
一个文本摘要和关键词提取器,负责分析用户输入的文本,并根据请求任务提供简洁的摘要或提取相关关键词。

Args:
    llm (TrainableModule): 可训练的模块
    language (str): 语言种类,目前只支持中文(zh)和英文(en)
    task_type (str): 目前支持两种任务:摘要(summary)和关键词抽取(keywords)。
    num_workers(int):控制并行处理的线程/进程数量。


Examples:

    >>> from lazyllm import TrainableModule
    >>> from lazyllm.tools.rag import LLMParser
    >>> llm = TrainableModule("internlm2-chat-7b")
    >>> summary_parser = LLMParser(llm, language="en", task_type="summary")
    """
    def __init__(self, llm: TrainableModule, language: str, task_type: str, num_workers: int = 30):
        super(__class__, self).__init__(num_workers=num_workers)
        assert language in ['en', 'zh'], f'Not supported language {language}'
        assert task_type in ['summary', 'keywords', 'qa', 'qa_img'], f'Not supported task_type {task_type}'
        self._task_type = task_type
        if self._task_type == 'qa_img':
            prompt = dict(system=templates[language][task_type], user='{input}')
        else:
            prompt = dict(system=templates[language][task_type], user='#input:\n{input}\n#output:\n')
        self._llm = llm.share(prompt=AlpacaPrompter(prompt), stream=False, format=self._format)
        self._task_type = task_type

    def transform(self, node: DocNode, **kwargs) -> List[str]:
        """
在指定的文档上执行设定的任务。

Args:
    node (DocNode): 需要执行抽取任务的文档。


Examples:

    >>> import lazyllm
    >>> from lazyllm.tools import LLMParser
    >>> llm = lazyllm.TrainableModule("internlm2-chat-7b").start()
    >>> m = lazyllm.TrainableModule("bge-large-zh-v1.5").start()
    >>> summary_parser = LLMParser(llm, language="en", task_type="summary")
    >>> keywords_parser = LLMParser(llm, language="en", task_type="keywords")
    >>> documents = lazyllm.Document(dataset_path="/path/to/your/data", embed=m, manager=False)
    >>> rm = lazyllm.Retriever(documents, group_name='CoarseChunk', similarity='bm25', topk=6)
    >>> doc_nodes = rm("test")
    >>> summary_result = summary_parser.transform(doc_nodes[0])
    >>> keywords_result = keywords_parser.transform(doc_nodes[0])
    """
        if self._task_type == 'qa_img':
            inputs = encode_query_with_filepaths('Extract QA pairs from images.', [node.image_path])
        else:
            inputs = node.get_text()
        result = self._llm(inputs)
        return [result] if isinstance(result, str) else result

    def _format(self, input):
        if self._task_type == 'keywords':
            return [s.strip() for s in input.split(',')]
        elif self._task_type in ('qa', 'qa_img'):
            return [QADocNode(query=q.strip()[3:].strip(), answer=a.strip()[3:].strip()) for q, a in zip(
                list(filter(None, map(str.strip, input.split("\n"))))[::2],
                list(filter(None, map(str.strip, input.split("\n"))))[1::2])]
        return input

transform(node, **kwargs)

在指定的文档上执行设定的任务。

Parameters:

  • node (DocNode) –

    需要执行抽取任务的文档。

Examples:

>>> import lazyllm
>>> from lazyllm.tools import LLMParser
>>> llm = lazyllm.TrainableModule("internlm2-chat-7b").start()
>>> m = lazyllm.TrainableModule("bge-large-zh-v1.5").start()
>>> summary_parser = LLMParser(llm, language="en", task_type="summary")
>>> keywords_parser = LLMParser(llm, language="en", task_type="keywords")
>>> documents = lazyllm.Document(dataset_path="/path/to/your/data", embed=m, manager=False)
>>> rm = lazyllm.Retriever(documents, group_name='CoarseChunk', similarity='bm25', topk=6)
>>> doc_nodes = rm("test")
>>> summary_result = summary_parser.transform(doc_nodes[0])
>>> keywords_result = keywords_parser.transform(doc_nodes[0])
Source code in lazyllm/tools/rag/transform.py
    def transform(self, node: DocNode, **kwargs) -> List[str]:
        """
在指定的文档上执行设定的任务。

Args:
    node (DocNode): 需要执行抽取任务的文档。


Examples:

    >>> import lazyllm
    >>> from lazyllm.tools import LLMParser
    >>> llm = lazyllm.TrainableModule("internlm2-chat-7b").start()
    >>> m = lazyllm.TrainableModule("bge-large-zh-v1.5").start()
    >>> summary_parser = LLMParser(llm, language="en", task_type="summary")
    >>> keywords_parser = LLMParser(llm, language="en", task_type="keywords")
    >>> documents = lazyllm.Document(dataset_path="/path/to/your/data", embed=m, manager=False)
    >>> rm = lazyllm.Retriever(documents, group_name='CoarseChunk', similarity='bm25', topk=6)
    >>> doc_nodes = rm("test")
    >>> summary_result = summary_parser.transform(doc_nodes[0])
    >>> keywords_result = keywords_parser.transform(doc_nodes[0])
    """
        if self._task_type == 'qa_img':
            inputs = encode_query_with_filepaths('Extract QA pairs from images.', [node.image_path])
        else:
            inputs = node.get_text()
        result = self._llm(inputs)
        return [result] if isinstance(result, str) else result

lazyllm.tools.rag.transform.NodeTransform members: exclude-members:

lazyllm.tools.rag.transform.TransformArgs dataclass

文档转换参数容器,用于统一管理文档处理中的各类配置参数。 Args: f(Union[str, Callable]):转换函数或注册的函数名。 trans_node(bool):是否转换节点类型。 num_workers(int):控制是否启用多线程(>0 时启用)。 kwargs(Dict):传递给转换函数的额外参数。 pattern(Union[str, Callable[[str], bool]]):文件名/内容匹配模式。

Examples:

>>> from lazyllm.tools import TransformArgs
>>> args = TransformArgs(f=lambda text: text.lower(),num_workers=4,pattern=r'.*\.md$')
>>>config = {'f': 'parse_pdf','kwargs': {'engine': 'pdfminer'},'trans_node': True}
>>>args = TransformArgs.from_dict(config)
print(args['f'])
print(args.get('unknown'))
Source code in lazyllm/tools/rag/transform.py
@dataclass
class TransformArgs():
    """
文档转换参数容器,用于统一管理文档处理中的各类配置参数。
Args:
    f(Union[str, Callable]):转换函数或注册的函数名。
    trans_node(bool):是否转换节点类型。
    num_workers(int):控制是否启用多线程(>0 时启用)。
    kwargs(Dict):传递给转换函数的额外参数。
    pattern(Union[str, Callable[[str], bool]]):文件名/内容匹配模式。


Examples:

    >>> from lazyllm.tools import TransformArgs
    >>> args = TransformArgs(f=lambda text: text.lower(),num_workers=4,pattern=r'.*\.md$')
    >>>config = {'f': 'parse_pdf','kwargs': {'engine': 'pdfminer'},'trans_node': True}
    >>>args = TransformArgs.from_dict(config)
    print(args['f'])
    print(args.get('unknown'))
    """
    f: Union[str, Callable]
    trans_node: Optional[bool] = None
    num_workers: int = 0
    kwargs: Dict = field(default_factory=dict)
    pattern: Optional[Union[str, Callable[[str], bool]]] = None

    @staticmethod
    def from_dict(d):
        return TransformArgs(f=d['f'], trans_node=d.get('trans_node'), num_workers=d.get(
            'num_workers', 0), kwargs=d.get('kwargs', dict()), pattern=d.get('pattern'))

    def __getitem__(self, key):
        if key in self.__dict__: return getattr(self, key)
        raise KeyError(f'Key {key} is not found in transform args')

    def get(self, key):
        if key in self.__dict__: return getattr(self, key)
        return None

lazyllm.tools.rag.similarity.register_similarity(func=None, mode=None, descend=True, batch=False)

相似度计算注册装饰器,用于统一注册和管理不同类型的相似度计算方法。 Args: func(Callable):相似度计算函数名。 mode(Literal['text', 'embedding']):text为文本直接匹配,embedding为向量相似度计算。 descend(bool):控制是否启用多线程(>0 时启用)。 kwargs(Dict):结果是否按相似度降序排列。 batch(bool):是否批量处理节点。

Source code in lazyllm/tools/rag/similarity.py
def register_similarity(
    func: Optional[Callable] = None,
    mode: Optional[Literal['text', 'embedding']] = None,
    descend: bool = True,
    batch: bool = False,
) -> Callable:
    """
相似度计算注册装饰器,用于统一注册和管理不同类型的相似度计算方法。
Args:
    func(Callable):相似度计算函数名。
    mode(Literal['text', 'embedding']):text为文本直接匹配,embedding为向量相似度计算。
    descend(bool):控制是否启用多线程(>0 时启用)。
    kwargs(Dict):结果是否按相似度降序排列。
    batch(bool):是否批量处理节点。
"""
    def decorator(f):
        @functools.wraps(f)
        def wrapper(query, nodes, **kwargs):
            if mode != "embedding":
                if batch:
                    return f(query, nodes, **kwargs)
                else:
                    return [(node, f(query, node, **kwargs)) for node in nodes]
            else:
                assert isinstance(query, dict), "query must be of dict type, used for similarity calculation."
                similarity = {}
                if batch:
                    for key, val in query.items():
                        nodes_embed = [node.embedding[key] for node in nodes]
                        similarity[key] = [(node, sim) for node, sim in zip(nodes, f(val, nodes_embed, **kwargs))]
                else:
                    for key, val in query.items():
                        similarity[key] = [(node, f(val, node.embedding[key], **kwargs)) for node in nodes]
                return similarity
        registered_similarities[f.__name__] = (wrapper, mode, descend)
        return wrapper

    return decorator(func) if func else decorator

lazyllm.tools.rag.doc_node.DocNode

在指定的文档上执行设定的任务。 Args: uid(str): 唯一标识符。 content(Union[str, List[Any]]):节点内容 group(str):文档组名 embedding(Dict[str, List[float]]):嵌入向量字典 parent(Union[str, "DocNode"]):父节点引用 store:存储表示 node_groups(Dict[str, Dict]):节点存储组 metadata(Dict[str, Any]):节点级元数据 global_metadata(Dict[str, Any]):文档级元数据 text(str):节点内容与content互斥

Source code in lazyllm/tools/rag/doc_node.py
@reset_on_pickle(('_lock', threading.Lock))
class DocNode:
    """
在指定的文档上执行设定的任务。
Args:
    uid(str): 唯一标识符。
    content(Union[str, List[Any]]):节点内容
    group(str):文档组名
    embedding(Dict[str, List[float]]):嵌入向量字典
    parent(Union[str, "DocNode"]):父节点引用
    store:存储表示
    node_groups(Dict[str, Dict]):节点存储组
    metadata(Dict[str, Any]):节点级元数据
    global_metadata(Dict[str, Any]):文档级元数据
    text(str):节点内容与content互斥
"""
    def __init__(self, uid: Optional[str] = None, content: Optional[Union[str, List[Any]]] = None,
                 group: Optional[str] = None, embedding: Optional[Dict[str, List[float]]] = None,
                 parent: Optional[Union[str, "DocNode"]] = None, store=None,
                 node_groups: Optional[Dict[str, Dict]] = None, metadata: Optional[Dict[str, Any]] = None,
                 global_metadata: Optional[Dict[str, Any]] = None, *, text: Optional[str] = None):
        if text and content:
            raise ValueError('`text` and `content` cannot be set at the same time.')
        if not content and not text: content = ''
        self._uid: str = uid if uid else str(uuid.uuid4())
        self._content: Optional[Union[str, List[Any]]] = content if content is not None else text
        self._group: Optional[str] = group
        self._embedding: Optional[Dict[str, List[float]]] = embedding or {}
        # metadata: the chunk's meta
        self._metadata: Dict[str, Any] = metadata or {}
        # Global metadata: the file's global metadata (higher level)
        self._global_metadata = global_metadata or {}
        # Metadata keys that are excluded from text for the embed model.
        self._excluded_embed_metadata_keys: List[str] = []
        # Metadata keys that are excluded from text for the LLM.
        self._excluded_llm_metadata_keys: List[str] = []
        # NOTE: node in parent should be id when stored in db (use store to recover): parent: 'uid'
        self._parent: Optional[Union[str, "DocNode"]] = parent
        self._children: Dict[str, List["DocNode"]] = defaultdict(list)
        self._children_loaded = False
        self._store = store
        self._node_groups: Dict[str, Dict] = node_groups or {}
        self._lock = threading.Lock()
        self._embedding_state = set()
        self.relevance_score = None
        self.similarity_score = None

    @property
    def uid(self) -> str:
        return self._uid

    @property
    def group(self) -> str:
        return self._group

    @property
    def text(self) -> str:
        if isinstance(self._content, str):
            return self._content
        elif isinstance(self._content, list):
            if unexcepted := set([type(ele) for ele in self._content if not isinstance(ele, str)]):
                raise TypeError(f"Found non-string element in content: {unexcepted}")
            return '\n'.join(self._content)
        else:
            raise TypeError(f"content type '{type(self._content)}' is neither a str nor a list")

    @property
    def embedding(self):
        return self._embedding

    @embedding.setter
    def embedding(self, v: Optional[Dict[str, List[float]]]):
        self._embedding = v

    def _load_from_store(self, group_name: str, uids: Union[str, List[str]]) -> List["DocNode"]:
        if not self._store or not uids:
            return []
        if isinstance(uids, str):
            uids = [uids]
        nodes = self._store.get_nodes(group_name=group_name, uids=uids,
                                      kb_id=self.global_metadata.get(RAG_KB_ID), display=True)
        for n in nodes:
            n._store = self._store
            n._node_groups = self._node_groups
        return nodes

    @property
    def parent(self) -> Optional["DocNode"]:
        if self._parent and isinstance(self._parent, str) and self._node_groups:
            parent_group = self._node_groups[self._group]["parent"]
            loaded = self._load_from_store(parent_group, self._parent)
            self._parent = loaded[0] if loaded else None
        return self._parent

    @parent.setter
    def parent(self, v: Optional["DocNode"]):
        self._parent = v

    @property
    def children(self) -> Dict[str, List["DocNode"]]:
        if not self._children_loaded and self._store and self._node_groups:
            self._children_loaded = True
            kb_id = self.global_metadata.get(RAG_KB_ID)
            doc_id = self.global_metadata.get(RAG_DOC_ID)
            c_groups = [grp for grp in self._node_groups.keys() if self._node_groups[grp]['parent'] == self._group]
            for grp in c_groups:
                if not self._store.is_group_active(grp):
                    continue
                nodes = self._store.get_nodes(group_name=grp, kb_id=kb_id, doc_ids=[doc_id])
                c_nodes = [n for n in nodes if n._parent in {self, self._uid}]
                self._children[grp] = c_nodes
                for n in self._children[grp]:
                    n._store = self._store
                    n._node_groups = self._node_groups
        return self._children

    @children.setter
    def children(self, v: Dict[str, List["DocNode"]]):
        self._children = v

    @property
    def root_node(self) -> "DocNode":
        node = self
        while isinstance(node._parent, DocNode):
            node = node._parent
        return node

    @property
    def is_root_node(self) -> bool:
        return (not self.parent)

    @property
    def global_metadata(self) -> Dict[str, Any]:
        return self.root_node._global_metadata

    @global_metadata.setter
    def global_metadata(self, global_metadata: Dict) -> None:
        self._global_metadata = global_metadata

    @property
    def metadata(self) -> Dict:
        return self._metadata

    @metadata.setter
    def metadata(self, metadata: Dict) -> None:
        self._metadata = metadata

    @property
    def excluded_embed_metadata_keys(self) -> List:
        return list(set(self.root_node._excluded_embed_metadata_keys + self._excluded_embed_metadata_keys))

    @excluded_embed_metadata_keys.setter
    def excluded_embed_metadata_keys(self, excluded_embed_metadata_keys: List) -> None:
        self._excluded_embed_metadata_keys = excluded_embed_metadata_keys

    @property
    def excluded_llm_metadata_keys(self) -> List:
        return list(set(self.root_node._excluded_llm_metadata_keys + self._excluded_llm_metadata_keys))

    @excluded_llm_metadata_keys.setter
    def excluded_llm_metadata_keys(self, excluded_llm_metadata_keys: List) -> None:
        self._excluded_llm_metadata_keys = excluded_llm_metadata_keys

    @property
    def docpath(self) -> str:
        return self.root_node.global_metadata.get(RAG_DOC_PATH, '')

    @docpath.setter
    def docpath(self, path):
        assert not self.parent, 'Only root node can set docpath'
        self.global_metadata[RAG_DOC_PATH] = str(path)

    def get_children_str(self) -> str:
        return str(
            {key: [node._uid for node in nodes] for key, nodes in self.children.items()}
        )

    def get_parent_id(self) -> str:
        return self.parent._uid if self.parent else ''

    def __str__(self) -> str:
        return (
            f"DocNode(id: {self._uid}, group: {self._group}, content: {self._content}) parent: {self.get_parent_id()}, "
            f"children: {self.get_children_str()}"
        )

    def __repr__(self) -> str:
        return str(self) if config["debug"] else f'<Node id={self._uid}>'

    def __eq__(self, other):
        if isinstance(other, DocNode):
            return self._uid == other._uid
        return False

    def __hash__(self):
        return hash(self._uid)

    def __getstate__(self):
        st = self.__dict__.copy()
        for attr in _pickle_blacklist:
            st[attr] = None
        return st

    def has_missing_embedding(self, embed_keys: Union[str, List[str]]) -> List[str]:
        """
检查缺失的嵌入向量
Args:
    embed_keys(Union[str, List[str]]): 目标键列表
"""
        if isinstance(embed_keys, str): embed_keys = [embed_keys]
        assert len(embed_keys) > 0, "The ebmed_keys to be checked must be passed in."
        if self.embedding is None: return embed_keys
        return [k for k in embed_keys if k not in self.embedding]

    def do_embedding(self, embed: Dict[str, Callable]) -> None:
        """
执行嵌入计算
Args:
    embed(Dict[str, Callable]): 目标嵌入对象
"""
        generate_embed = {k: e(self.get_text(MetadataMode.EMBED)) for k, e in embed.items()}
        with self._lock:
            self.embedding = self.embedding or {}
            self.embedding = {**self.embedding, **generate_embed}

    def check_embedding_state(self, embed_key: str) -> None:
        """
阻塞检查嵌入状态,确保异步嵌入计算完成
Args:
    embed_key(str): 目标键列表
"""
        while True:
            with self._lock:
                if not self.has_missing_embedding(embed_key):
                    self._embedding_state.discard(embed_key)
                    break
            time.sleep(1)

    def get_content(self) -> str:
        return self.get_text(MetadataMode.LLM)

    def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str:
        """Metadata info string."""
        if mode == MetadataMode.NONE:
            return ''

        metadata_keys = set(self.metadata.keys())
        if mode == MetadataMode.LLM:
            for key in self.excluded_llm_metadata_keys:
                if key in metadata_keys:
                    metadata_keys.remove(key)
        elif mode == MetadataMode.EMBED:
            for key in self.excluded_embed_metadata_keys:
                if key in metadata_keys:
                    metadata_keys.remove(key)

        return "\n".join([f"{key}: {self.metadata[key]}" for key in metadata_keys])

    def get_text(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
        """
组合元数据和内容
Args:
    metadata_mode: 与get_metadata_str中参数一致
"""
        metadata_str = self.get_metadata_str(metadata_mode).strip()
        if not metadata_str:
            return self.text if self.text else ''
        return f"{metadata_str}\n\n{self.text}".strip()

    def to_dict(self) -> Dict:
        """
转换为字典格式
"""
        return dict(content=self._content, embedding=self.embedding, metadata=self.metadata)

    def with_score(self, score):
        """
浅拷贝原节点并添加语义相关分数。
Args:
    score: 相关性得分
"""
        node = copy.copy(self)
        node.relevance_score = score
        return node

    def with_sim_score(self, score):
        """
浅拷贝原节点并添加相似度分数。
Args:
    score: 相似度得分
"""
        node = copy.copy(self)
        node.similarity_score = score
        return node

check_embedding_state(embed_key)

阻塞检查嵌入状态,确保异步嵌入计算完成 Args: embed_key(str): 目标键列表

Source code in lazyllm/tools/rag/doc_node.py
    def check_embedding_state(self, embed_key: str) -> None:
        """
阻塞检查嵌入状态,确保异步嵌入计算完成
Args:
    embed_key(str): 目标键列表
"""
        while True:
            with self._lock:
                if not self.has_missing_embedding(embed_key):
                    self._embedding_state.discard(embed_key)
                    break
            time.sleep(1)

do_embedding(embed)

执行嵌入计算 Args: embed(Dict[str, Callable]): 目标嵌入对象

Source code in lazyllm/tools/rag/doc_node.py
    def do_embedding(self, embed: Dict[str, Callable]) -> None:
        """
执行嵌入计算
Args:
    embed(Dict[str, Callable]): 目标嵌入对象
"""
        generate_embed = {k: e(self.get_text(MetadataMode.EMBED)) for k, e in embed.items()}
        with self._lock:
            self.embedding = self.embedding or {}
            self.embedding = {**self.embedding, **generate_embed}

get_metadata_str(mode=MetadataMode.ALL)

Metadata info string.

Source code in lazyllm/tools/rag/doc_node.py
def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str:
    """Metadata info string."""
    if mode == MetadataMode.NONE:
        return ''

    metadata_keys = set(self.metadata.keys())
    if mode == MetadataMode.LLM:
        for key in self.excluded_llm_metadata_keys:
            if key in metadata_keys:
                metadata_keys.remove(key)
    elif mode == MetadataMode.EMBED:
        for key in self.excluded_embed_metadata_keys:
            if key in metadata_keys:
                metadata_keys.remove(key)

    return "\n".join([f"{key}: {self.metadata[key]}" for key in metadata_keys])

get_text(metadata_mode=MetadataMode.NONE)

组合元数据和内容 Args: metadata_mode: 与get_metadata_str中参数一致

Source code in lazyllm/tools/rag/doc_node.py
    def get_text(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
        """
组合元数据和内容
Args:
    metadata_mode: 与get_metadata_str中参数一致
"""
        metadata_str = self.get_metadata_str(metadata_mode).strip()
        if not metadata_str:
            return self.text if self.text else ''
        return f"{metadata_str}\n\n{self.text}".strip()

has_missing_embedding(embed_keys)

检查缺失的嵌入向量 Args: embed_keys(Union[str, List[str]]): 目标键列表

Source code in lazyllm/tools/rag/doc_node.py
    def has_missing_embedding(self, embed_keys: Union[str, List[str]]) -> List[str]:
        """
检查缺失的嵌入向量
Args:
    embed_keys(Union[str, List[str]]): 目标键列表
"""
        if isinstance(embed_keys, str): embed_keys = [embed_keys]
        assert len(embed_keys) > 0, "The ebmed_keys to be checked must be passed in."
        if self.embedding is None: return embed_keys
        return [k for k in embed_keys if k not in self.embedding]

to_dict()

转换为字典格式

Source code in lazyllm/tools/rag/doc_node.py
    def to_dict(self) -> Dict:
        """
转换为字典格式
"""
        return dict(content=self._content, embedding=self.embedding, metadata=self.metadata)

with_score(score)

浅拷贝原节点并添加语义相关分数。 Args: score: 相关性得分

Source code in lazyllm/tools/rag/doc_node.py
    def with_score(self, score):
        """
浅拷贝原节点并添加语义相关分数。
Args:
    score: 相关性得分
"""
        node = copy.copy(self)
        node.relevance_score = score
        return node

with_sim_score(score)

浅拷贝原节点并添加相似度分数。 Args: score: 相似度得分

Source code in lazyllm/tools/rag/doc_node.py
    def with_sim_score(self, score):
        """
浅拷贝原节点并添加相似度分数。
Args:
    score: 相似度得分
"""
        node = copy.copy(self)
        node.similarity_score = score
        return node

lazyllm.tools.rag.doc_node.QADocNode

Bases: DocNode

问答文档节点类,用于存储问答对数据。

参数

query (str): 问题文本。 answer (str): 答案文本。 uid (str): 唯一标识符。 group (str): 文档组名。 embedding (Dict[str, List[float]]): 嵌入向量字典。 parent (DocNode): 父节点引用。 metadata (Dict[str, Any]): 节点级元数据。 global_metadata (Dict[str, Any]): 文档级元数据。 text (str): 节点内容,与query互斥。

Source code in lazyllm/tools/rag/doc_node.py
class QADocNode(DocNode):
    """问答文档节点类,用于存储问答对数据。

参数:
    query (str): 问题文本。
    answer (str): 答案文本。
    uid (str): 唯一标识符。
    group (str): 文档组名。
    embedding (Dict[str, List[float]]): 嵌入向量字典。
    parent (DocNode): 父节点引用。
    metadata (Dict[str, Any]): 节点级元数据。
    global_metadata (Dict[str, Any]): 文档级元数据。
    text (str): 节点内容,与query互斥。
"""
    def __init__(self, query: str, answer: str, uid: Optional[str] = None, group: Optional[str] = None,
                 embedding: Optional[Dict[str, List[float]]] = None, parent: Optional["DocNode"] = None,
                 metadata: Optional[Dict[str, Any]] = None, global_metadata: Optional[Dict[str, Any]] = None,
                 *, text: Optional[str] = None):
        super().__init__(uid, query, group, embedding, parent, metadata, global_metadata=global_metadata, text=text)
        self._answer = answer.strip()

    @property
    def answer(self) -> str:
        return self._answer

    def get_text(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
        """获取节点的文本内容。

参数:
    metadata_mode (MetadataMode): 元数据模式,默认为MetadataMode.NONE。
        当设置为MetadataMode.LLM时,返回格式化的问答对。
        其他模式下返回基类的文本格式。

返回值:
    str: 格式化后的文本内容。
"""
        if metadata_mode == MetadataMode.LLM:
            return f'query:\n{self.text}\nanswer\n{self._answer}'
        return super().get_text(metadata_mode)

get_text(metadata_mode=MetadataMode.NONE)

获取节点的文本内容。

参数

metadata_mode (MetadataMode): 元数据模式,默认为MetadataMode.NONE。 当设置为MetadataMode.LLM时,返回格式化的问答对。 其他模式下返回基类的文本格式。

返回值

str: 格式化后的文本内容。

Source code in lazyllm/tools/rag/doc_node.py
    def get_text(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
        """获取节点的文本内容。

参数:
    metadata_mode (MetadataMode): 元数据模式,默认为MetadataMode.NONE。
        当设置为MetadataMode.LLM时,返回格式化的问答对。
        其他模式下返回基类的文本格式。

返回值:
    str: 格式化后的文本内容。
"""
        if metadata_mode == MetadataMode.LLM:
            return f'query:\n{self.text}\nanswer\n{self._answer}'
        return super().get_text(metadata_mode)

lazyllm.tools.rag.doc_processor.DocumentProcessor

Bases: ModuleBase

文档处理器类,用于管理文档的添加、删除和更新操作。

Parameters:

  • server (bool, default: True ) –

    是否以服务器模式运行。默认为True。

  • port (Optional[int], default: None ) –

    服务器端口号。默认为None。

  • url (Optional[str], default: None ) –

    远程服务URL。默认为None。

说明: - 支持异步处理文档任务 - 提供文档元数据更新功能 - 支持任务状态回调通知 - 可配置数据库存储

Examples:

```python
# Create local document processor
processor = DocumentProcessor(server=False)

# Create server mode document processor
processor = DocumentProcessor(server=True, port=8080)

# Create remote document processor
processor = DocumentProcessor(url="http://remote-server:8080")
```
Source code in lazyllm/tools/rag/doc_processor.py
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
class DocumentProcessor(ModuleBase):
    """
文档处理器类,用于管理文档的添加、删除和更新操作。

Args:
    server (bool): 是否以服务器模式运行。默认为True。
    port (Optional[int]): 服务器端口号。默认为None。
    url (Optional[str]): 远程服务URL。默认为None。

**说明:**
- 支持异步处理文档任务
- 提供文档元数据更新功能
- 支持任务状态回调通知
- 可配置数据库存储


Examples:

    ```python
    # Create local document processor
    processor = DocumentProcessor(server=False)

    # Create server mode document processor
    processor = DocumentProcessor(server=True, port=8080)

    # Create remote document processor
    processor = DocumentProcessor(url="http://remote-server:8080")
    ```
    """

    class Impl():
        def __init__(self, server: bool):
            self._processors: Dict[str, _Processor] = dict()
            self._server = server
            self._inited = False
            try:
                self._feedback_url = config['process_feedback_service']
                self._path_prefix = config['process_path_prefix']
            except Exception as e:
                LOG.warning(f"Failed to get config: {e}, use env variables instead")
                self._feedback_url = os.getenv("PROCESS_FEEDBACK_SERVICE", None)
                self._path_prefix = os.getenv("PROCESS_PATH_PREFIX", None)

        def _init_components(self, server: bool):
            if server and not self._inited:
                self._task_queue = queue.Queue()
                self._tasks = {}    # running tasks
                self._pending_task_ids = set()  # pending tasks
                self._add_executor = ThreadPoolExecutor(max_workers=4)
                self._add_futures = {}
                self._delete_executor = ThreadPoolExecutor(max_workers=4)
                self._update_executor = ThreadPoolExecutor(max_workers=4)
                self._update_futures = {}

                self._engines: dict[str, Engine] = {}
                self._inspectors: dict[str, inspect] = {}

                self._worker_thread = threading.Thread(target=self._worker, daemon=True)
                self._worker_thread.start()
            self._inited = True
            LOG.info(f"[DocumentProcessor] init done. feedback {self._feedback_url}, prefix {self._path_prefix}")

        def register_algorithm(self, name: str, store: _DocumentStore, reader: ReaderBase,
                               node_groups: Dict[str, Dict], display_name: Optional[str] = None,
                               description: Optional[str] = None, force_refresh: bool = False):
            self._init_components(server=self._server)
            if name in self._processors and not force_refresh:
                LOG.warning(f'There is already a processor with the same name {name}!')
                return
            self._processors[name] = _Processor(store, reader, node_groups, display_name, description)
            LOG.info(f'Processor {name} registered!')

        def drop_algorithm(self, name: str, clean_db: bool = False) -> None:
            if name not in self._processors:
                LOG.warning(f'Processor {name} not found!')
                return
            self._processors.pop(name)

        def _get_engine(self, url) -> Engine:
            if url not in self._engines:
                engine = create_engine(url, echo=False, pool_pre_ping=True)
                self._engines[url] = engine
                self._inspectors[url] = inspect(engine)
            return self._engines[url]

        def _get_inspector(self, url):
            self._get_engine(url=url)
            return self._inspectors[url]

        def _get_url_from_db_info(self, db_info: DBInfo):
            return (f"mysql+pymysql://{db_info.user}:{db_info.password}"
                    f"@{db_info.host}:{db_info.port}/{db_info.db_name}"
                    "?charset=utf8mb4")

        def create_table(self, db_info: DBInfo):
            if db_info.db_type == "mysql":
                try:
                    url = self._get_url_from_db_info(db_info)
                    engine = self._get_engine(url=url)
                    inspector = self._get_inspector(url=url)
                    tbl = db_info.table_name
                    schema = db_info.db_name

                    if not inspector.has_table(tbl, schema=schema):
                        metadata = MetaData()
                        table = Table(tbl, metadata, Column('document_id', String(255), primary_key=True),
                                      Column('file_name', String(255), nullable=False),
                                      Column('file_path', String(255), nullable=False),
                                      Column('description', String(255), nullable=True),
                                      Column('creater', String(255), nullable=False),
                                      Column('dataset_id', String(255), nullable=False),
                                      Column('tags', JSON, nullable=True),
                                      Column('created_at', TIMESTAMP, server_default=text("CURRENT_TIMESTAMP")))
                        metadata.create_all(engine, tables=[table])
                        LOG.info(f"Created table `{tbl}` in `{schema}`")
                except Exception as e:
                    LOG.error(f"Failed to create table `{tbl}` in `{schema}`: {e}")
                    return
            else:
                raise ValueError(f"Unsupported database type: {db_info.db_type}")

        def operate_db(self, db_info: DBInfo, operation: str,
                       file_infos: List[FileInfo] = None, params: Dict = None) -> None:
            db_type = db_info.db_type
            if db_type not in DB_TYPES:
                raise ValueError(f"Unsupported db_type: {db_type}")
            url = self._get_url_from_db_info(db_info)
            engine = self._get_engine(url=url)
            if operation == 'upsert':
                self._upsert_records(engine, db_info, file_infos)
            elif operation == 'delete':
                self._delete_records(engine, db_info, params)
            else:
                raise ValueError(f"Unsupported operation: {operation}")

        def _upsert_records(self, engine, db_info, file_infos):
            table_name = db_info['table_name']
            metadata = MetaData()
            metadata.reflect(bind=engine, only=[table_name])
            table = metadata.tables[table_name]
            with engine.begin() as conn:
                for file_info in file_infos:
                    document_id = file_info.get("doc_id")
                    file_path = file_info.get("file_path")
                    if not document_id or not file_path:
                        raise ValueError(f"Invalid file_info: {file_info}")

                    raw_infos = {"document_id": document_id, "file_name": os.path.basename(file_path),
                                 "file_path": file_path, "description": file_info["metadata"].get("description", None),
                                 "creater": file_info["metadata"].get("creater", None),
                                 "dataset_id": file_info["metadata"].get(RAG_KB_ID, None),
                                 "tags": file_info["metadata"].get("tags", []) or []}
                    infos = {}
                    for k, v in raw_infos.items():
                        if v is None:
                            continue
                        if isinstance(v, str) and not v.strip():
                            continue
                        if isinstance(v, (list, dict)) and not v:
                            continue
                        infos[k] = v
                    if "document_id" not in infos:
                        infos["document_id"] = document_id

                    stmt = mysql_insert(table).values(**infos)
                    update_dict = {k: stmt.inserted[k] for k in infos if k != 'document_id'}
                    upsert_stmt = stmt.on_duplicate_key_update(**update_dict)
                    conn.execute(upsert_stmt)

        def _delete_records(self, engine, db_info, params):
            table_name = db_info['table_name']
            metadata = MetaData()
            metadata.reflect(bind=engine, only=[table_name])
            table = metadata.tables[table_name]

            with engine.begin() as conn:  # 自动提交或回滚事务
                doc_ids = params.get("doc_ids", [])
                for document_id in doc_ids:
                    stmt = delete(table).where(table.c.document_id == document_id)
                    conn.execute(stmt)

        @app.get('/algo/list')
        async def get_algo_list(self) -> None:
            res = []
            for algo_id, processor in self._processors.items():
                res.append({"algo_id": algo_id, "display_name": processor._display_name,
                            "description": processor._description})
            return BaseResponse(code=200, msg='success', data=res)

        @app.get('/group/info')
        async def get_group_info(self, algo_id: str) -> None:
            if algo_id not in self._processors:
                return BaseResponse(code=400, msg=f"Invalid algo_id {algo_id}")
            processor = self._processors[algo_id]
            infos = []
            for group_name in processor._store.activated_groups():
                if group_name in processor._node_groups:
                    group_info = {"name": group_name, "type": processor._node_groups[group_name].get('group_type'),
                                  "display_name": processor._node_groups[group_name].get('display_name')}
                    infos.append(group_info)
            LOG.info(f"Get group info for {algo_id} success with {infos}")
            return BaseResponse(code=200, msg='success', data=infos)

        @app.post('/doc/add')
        async def async_add_doc(self, request: AddDocRequest):
            LOG.info(f"Add doc for {request.model_dump_json()}")
            task_id = request.task_id
            algo_id = request.algo_id
            file_infos = request.file_infos
            db_info = request.db_info
            feedback_url = request.feedback_url
            if algo_id not in self._processors:
                return BaseResponse(code=400, msg=f"Invalid algo_id {algo_id}")
            if task_id in self._pending_task_ids or task_id in self._tasks:
                return BaseResponse(code=400, msg=f'The task {task_id} already exists in queue', data=None)
            if self._path_prefix:
                for file_info in file_infos:
                    file_info.file_path = create_file_path(path=file_info.file_path, prefix=self._path_prefix)

            params = {"file_infos": file_infos, "db_info": db_info, "feedback_url": feedback_url}
            if ENABLE_DB:
                self.create_table(db_info=db_info)

            self._task_queue.put(('add', algo_id, task_id, params))
            self._pending_task_ids.add(task_id)
            return BaseResponse(code=200, msg='task submit successfully', data={"task_id": task_id})

        @app.post('/doc/meta/update')
        async def async_update_meta(self, request: UpdateMetaRequest):
            LOG.info(f"update doc meta for {request.model_dump_json()}")
            algo_id = request.algo_id
            file_infos = request.file_infos
            db_info = request.db_info

            if algo_id not in self._processors:
                return BaseResponse(code=400, msg=f"Invalid algo_id {algo_id}")

            for file_info in file_infos:
                doc_id = file_info.doc_id
                metadata = file_info.metadata
                old_fut = self._update_futures.get(doc_id)
                if old_fut and not old_fut.done():
                    cancelled = old_fut.cancel()
                    LOG.info(f"Canceled previous update for {doc_id}: {cancelled}")

                new_fut = self._update_executor.submit(self._processors[algo_id].update_doc_meta, doc_id=doc_id,
                                                       metadata=metadata)

                self._update_futures[doc_id] = new_fut

                def _cleanup(fut, doc_id=doc_id):
                    if self._update_futures.get(doc_id) is fut:
                        del self._update_futures[doc_id]
                new_fut.add_done_callback(_cleanup)
                if ENABLE_DB:
                    new_fut.add_done_callback(
                        lambda fut, dbi=db_info, fi=file_info: self.operate_db(dbi, 'upsert', file_infos=[fi]))

            return BaseResponse(code=200, msg='success')

        @app.delete('/doc/delete')
        async def async_delete_doc(self, request: DeleteDocRequest) -> None:
            LOG.info(f"Del doc for {request.model_dump_json()}")
            algo_id = request.algo_id
            dataset_id = request.dataset_id
            doc_ids = request.doc_ids
            db_info = request.db_info

            if algo_id not in self._processors:
                return BaseResponse(code=400, msg=f"Invalid algo_id {algo_id}")

            task_id = str(uuid.uuid4())
            self._task_queue.put(('delete', algo_id, task_id,
                                  {"dataset_id": dataset_id, "doc_ids": doc_ids, "db_info": db_info}))
            self._pending_task_ids.add(task_id)
            return BaseResponse(code=200, msg='task submit successfully', data={"task_id": task_id})

        @app.post('/doc/cancel')
        async def cancel_task(self, request: CancelDocRequest):
            task_id = request.task_id
            if task_id in self._pending_task_ids:
                self._pending_task_ids.remove(task_id)
                status = 1
            elif task_id in self._tasks:
                future = self._tasks.get(task_id)
                if future and not future.done():
                    cancelled = future.cancel()
                    status = 1 if cancelled else 0
                    if cancelled:
                        self._tasks.pop(task_id, None)
                else:
                    status = 0
            return BaseResponse(code=200, msg="success" if status else "failed",
                                data={"task_id": task_id, "status": status})

        def _send_status_message(self, task_id: str, callback_path: str, success: bool,
                                 error_code: str = "", error_msg: str = ""):
            if self._feedback_url:
                try:
                    full_url = self._feedback_url + callback_path
                    payload = {"task_id": task_id, "status": 1 if success else 0, "error_code": error_code,
                               "error_msg": error_msg}
                    headers = {"Content-Type": "application/json"}
                    res = None
                    for wait_time in fibonacci_backoff(max_retries=3):
                        try:
                            res = requests.post(full_url, json=payload, headers=headers, timeout=5)
                            if res.status_code == 200:
                                break
                            LOG.warning(
                                f"Task-{task_id}: Unexpected status {res.status_code}, retrying in {wait_time}s…")
                        except Exception as e:
                            LOG.error(f"Task-{task_id}: Request failed: {e}, retrying in {wait_time}s…")
                        time.sleep(wait_time)

                    if res is None:
                        raise RuntimeError("Failed to send feedback—no response received after retries")
                    res.raise_for_status()
                except Exception as e:
                    LOG.error(f"Task-{task_id}: Failed to send feedback to {full_url}: {e}")
            else:
                LOG.error("process_feedback_service is not set")

        def _exec_add_task(self, algo_id, task_id, params):
            try:
                file_infos: List[FileInfo] = params.get('file_infos')
                callback_path = params.get('feedback_url')
                db_info: DBInfo = params.get('db_info')

                input_files = []
                ids = []
                metadatas = []

                reparse_group = None
                reparse_doc_ids = []
                reparse_files = []
                reparse_metadatas = []

                for file_info in file_infos:
                    if file_info.reparse_group:
                        reparse_group = file_info.reparse_group
                        reparse_doc_ids.append(file_info.doc_id)
                        reparse_files.append(file_info.file_path)
                        reparse_metadatas.append(file_info.metadata)
                    else:
                        input_files.append(file_info.file_path)
                        ids.append(file_info.doc_id)
                        metadatas.append(file_info.metadata)

                if input_files:
                    future = self._add_executor.submit(self._processors[algo_id].add_doc, input_files=input_files,
                                                       ids=ids, metadatas=metadatas)
                    if ENABLE_DB:
                        future.add_done_callback(lambda fut: self.operate_db(db_info, 'upsert', file_infos=file_infos))
                elif reparse_group:
                    future = self._add_executor.submit(self._processors[algo_id].reparse, group_name=reparse_group,
                                                       doc_ids=reparse_doc_ids, doc_paths=reparse_files,
                                                       metadatas=reparse_metadatas)
                else:
                    LOG.error(
                        f"Task-{task_id}: add task error, no input files {input_files} or reparse group {reparse_group}"
                    )
                self._tasks[task_id] = (future, callback_path)
                self._pending_task_ids.remove(task_id)
            except Exception as e:
                LOG.error(f"Task-{task_id}: add task error {e}")

        def _exec_delete_task(self, algo_id, task_id, params):
            dataset_id = params.get("dataset_id")
            doc_ids = params.get("doc_ids")
            future = self._delete_executor.submit(
                self._processors[algo_id].delete_doc, dataset_id=dataset_id, doc_ids=doc_ids
            )
            if ENABLE_DB and params.get("db_info") is not None:
                db_info = params.get("db_info")
                future.add_done_callback(lambda fut: self.operate_db(db_info, 'delete', params=params))
            self._tasks[task_id] = (future, None)
            self._pending_task_ids.remove(task_id)

        def _worker(self):  # noqa: C901
            while True:
                try:
                    task_type, algo_id, task_id, params = self._task_queue.get(timeout=1)
                    if task_id not in self._pending_task_ids:
                        continue
                    if task_type == 'add':
                        self._exec_add_task(algo_id=algo_id, task_id=task_id, params=params)
                    elif task_type == 'delete':
                        self._exec_delete_task(algo_id=algo_id, task_id=task_id, params=params)
                    time.sleep(0.2)
                except queue.Empty:
                    task_need_pop = []
                    for task_id, (future, callback_path) in self._tasks.items():
                        if future.done():
                            task_need_pop.append(task_id)
                            ex = future.exception()
                            if callback_path and not ex:
                                self._send_status_message(task_id=task_id, callback_path=callback_path, success=True,
                                                          error_code="", error_msg="")
                            elif callback_path and ex:
                                self._send_status_message(task_id=task_id, callback_path=callback_path, success=False,
                                                          error_code=type(ex).__name__, error_msg=str(ex))
                                LOG.error(f"task {task_id} failed: {str(ex)}")
                            elif ex:
                                LOG.error(f"task {task_id} failed: {str(ex)}")
                    for task_id in task_need_pop:
                        self._tasks.pop(task_id)
                        LOG.info(f"task {task_id} done")
                    time.sleep(5)

        def __call__(self, func_name: str, *args, **kwargs):
            return getattr(self, func_name)(*args, **kwargs)

    def __init__(self, server: bool = True, port: int = None, url: str = None):
        super().__init__()
        if not url:
            self._impl = DocumentProcessor.Impl(server=server)
            if server:
                self._impl = ServerModule(self._impl, port=port)
        else:
            self._impl = UrlModule(url=ensure_call_endpoint(url))

    def _dispatch(self, method: str, *args, **kwargs):
        impl = self._impl
        if isinstance(impl, ServerModule):
            impl._call(method, *args, **kwargs)
        else:
            getattr(impl, method)(*args, **kwargs)

    def register_algorithm(self, name: str, store: _DocumentStore, reader: ReaderBase, node_groups: Dict[str, Dict],
                           display_name: Optional[str] = None, description: Optional[str] = None,
                           force_refresh: bool = False, **kwargs):
        """
注册算法到文档处理器。

Args:
    name (str): 算法名称,作为唯一标识符。
    store (StoreBase): 存储实例,用于管理文档数据。
    reader (ReaderBase): 读取器实例,用于解析文档内容。
    node_groups (Dict[str, Dict]): 节点组配置信息。
    force_refresh (bool): 是否强制刷新已存在的算法。默认为False。

**说明:**
- 如果算法名称已存在且force_refresh为False,将跳过注册
- 注册成功后可以使用该算法处理文档


Examples:

    ```python
    from lazyllm.rag import DocumentProcessor, FileStore, PDFReader

    # Create storage and reader instances
    store = FileStore(path="./data")
    reader = PDFReader()

    # Define node group configuration
    node_groups = {
        "text": {"transform": "text", "parent": "root"},
        "summary": {"transform": "summary", "parent": "text"}
    }

    # Register algorithm
    processor = DocumentProcessor()
    processor.register_algorithm(
        name="pdf_processor",
        store=store,
        reader=reader,
        node_groups=node_groups
    )
    ```
    """
        self._dispatch("register_algorithm", name, store, reader, node_groups,
                       display_name, description, force_refresh, **kwargs)

    def drop_algorithm(self, name: str, clean_db: bool = False) -> None:
        """
从文档处理器中移除指定算法。

Args:
    name (str): 要移除的算法名称。
    clean_db (bool): 是否清理相关数据库数据。默认为False。

**说明:**
- 如果算法名称不存在,将输出警告信息
- 移除后该算法将无法继续使用


Examples:

    ```python
    # Remove algorithm
    processor.drop_algorithm("pdf_processor")

    # Remove algorithm and clean database
    processor.drop_algorithm("pdf_processor", clean_db=True)
    ```
    """
        return self._dispatch("drop_algorithm", name, clean_db)

drop_algorithm(name, clean_db=False)

从文档处理器中移除指定算法。

Parameters:

  • name (str) –

    要移除的算法名称。

  • clean_db (bool, default: False ) –

    是否清理相关数据库数据。默认为False。

说明: - 如果算法名称不存在,将输出警告信息 - 移除后该算法将无法继续使用

Examples:

```python
# Remove algorithm
processor.drop_algorithm("pdf_processor")

# Remove algorithm and clean database
processor.drop_algorithm("pdf_processor", clean_db=True)
```
Source code in lazyllm/tools/rag/doc_processor.py
    def drop_algorithm(self, name: str, clean_db: bool = False) -> None:
        """
从文档处理器中移除指定算法。

Args:
    name (str): 要移除的算法名称。
    clean_db (bool): 是否清理相关数据库数据。默认为False。

**说明:**
- 如果算法名称不存在,将输出警告信息
- 移除后该算法将无法继续使用


Examples:

    ```python
    # Remove algorithm
    processor.drop_algorithm("pdf_processor")

    # Remove algorithm and clean database
    processor.drop_algorithm("pdf_processor", clean_db=True)
    ```
    """
        return self._dispatch("drop_algorithm", name, clean_db)

register_algorithm(name, store, reader, node_groups, display_name=None, description=None, force_refresh=False, **kwargs)

注册算法到文档处理器。

Parameters:

  • name (str) –

    算法名称,作为唯一标识符。

  • store (StoreBase) –

    存储实例,用于管理文档数据。

  • reader (ReaderBase) –

    读取器实例,用于解析文档内容。

  • node_groups (Dict[str, Dict]) –

    节点组配置信息。

  • force_refresh (bool, default: False ) –

    是否强制刷新已存在的算法。默认为False。

说明: - 如果算法名称已存在且force_refresh为False,将跳过注册 - 注册成功后可以使用该算法处理文档

Examples:

```python
from lazyllm.rag import DocumentProcessor, FileStore, PDFReader

# Create storage and reader instances
store = FileStore(path="./data")
reader = PDFReader()

# Define node group configuration
node_groups = {
    "text": {"transform": "text", "parent": "root"},
    "summary": {"transform": "summary", "parent": "text"}
}

# Register algorithm
processor = DocumentProcessor()
processor.register_algorithm(
    name="pdf_processor",
    store=store,
    reader=reader,
    node_groups=node_groups
)
```
Source code in lazyllm/tools/rag/doc_processor.py
    def register_algorithm(self, name: str, store: _DocumentStore, reader: ReaderBase, node_groups: Dict[str, Dict],
                           display_name: Optional[str] = None, description: Optional[str] = None,
                           force_refresh: bool = False, **kwargs):
        """
注册算法到文档处理器。

Args:
    name (str): 算法名称,作为唯一标识符。
    store (StoreBase): 存储实例,用于管理文档数据。
    reader (ReaderBase): 读取器实例,用于解析文档内容。
    node_groups (Dict[str, Dict]): 节点组配置信息。
    force_refresh (bool): 是否强制刷新已存在的算法。默认为False。

**说明:**
- 如果算法名称已存在且force_refresh为False,将跳过注册
- 注册成功后可以使用该算法处理文档


Examples:

    ```python
    from lazyllm.rag import DocumentProcessor, FileStore, PDFReader

    # Create storage and reader instances
    store = FileStore(path="./data")
    reader = PDFReader()

    # Define node group configuration
    node_groups = {
        "text": {"transform": "text", "parent": "root"},
        "summary": {"transform": "summary", "parent": "text"}
    }

    # Register algorithm
    processor = DocumentProcessor()
    processor.register_algorithm(
        name="pdf_processor",
        store=store,
        reader=reader,
        node_groups=node_groups
    )
    ```
    """
        self._dispatch("register_algorithm", name, store, reader, node_groups,
                       display_name, description, force_refresh, **kwargs)

lazyllm.tools.rag.dataReader.SimpleDirectoryReader

Bases: ModuleBase

模块化的文档目录读取器,继承自 ModuleBase,支持从文件系统读取多种格式的文档并转换为标准化的 DocNode 。 Args: input_dir (Optional[str]): 输入目录路径。与input_files二选一,不可同时指定。 input_files (Optional[List]):直接指定的文件列表。与input_dir二选一。 exclude (Optional[List]):需要排除的文件模式列表。 exclude_hidden (bool): 是否排除隐藏文件。 recursive (bool):是否递归读取子目录。 encoding (str):文本文件的编码格式。 required_exts (Optional[List[str]]):需要处理的文件扩展名白名单。 file_extractor (Optional[Dict[str, Callable]]):自定义文件阅读器字典。 fs (Optional[AbstractFileSystem]):自定义文件系统。 metadata_genf (Optional[Callable[[str], Dict]]):元数据生成函数,接收文件路径返回元数据字典。 num_files_limit (Optional[int]):最大读取文件数量限制。 return_trace (bool):是否返回处理过程追踪信息。 metadatas (Optional[Dict]):预定义的全局元数据字典。

Examples:

>>> import lazyllm
>>> from lazyllm.tools.dataReader import SimpleDirectoryReader
>>> reader = SimpleDirectoryReader(input_dir="yourpath/",recursive=True,exclude=["*.tmp"],required_exts=[".pdf", ".docx"])
>>> documents = reader.load_data()
Source code in lazyllm/tools/rag/dataReader.py
class SimpleDirectoryReader(ModuleBase):
    """
模块化的文档目录读取器,继承自 ModuleBase,支持从文件系统读取多种格式的文档并转换为标准化的 DocNode 。
Args:
    input_dir (Optional[str]): 输入目录路径。与input_files二选一,不可同时指定。
    input_files (Optional[List]):直接指定的文件列表。与input_dir二选一。
    exclude (Optional[List]):需要排除的文件模式列表。
    exclude_hidden (bool): 是否排除隐藏文件。
    recursive (bool):是否递归读取子目录。
    encoding (str):文本文件的编码格式。
    required_exts (Optional[List[str]]):需要处理的文件扩展名白名单。
    file_extractor (Optional[Dict[str, Callable]]):自定义文件阅读器字典。
    fs (Optional[AbstractFileSystem]):自定义文件系统。
    metadata_genf (Optional[Callable[[str], Dict]]):元数据生成函数,接收文件路径返回元数据字典。
    num_files_limit (Optional[int]):最大读取文件数量限制。
    return_trace (bool):是否返回处理过程追踪信息。
    metadatas (Optional[Dict]):预定义的全局元数据字典。


Examples:

    >>> import lazyllm
    >>> from lazyllm.tools.dataReader import SimpleDirectoryReader
    >>> reader = SimpleDirectoryReader(input_dir="yourpath/",recursive=True,exclude=["*.tmp"],required_exts=[".pdf", ".docx"])
    >>> documents = reader.load_data()
    """
    default_file_readers: Dict[str, Type[ReaderBase]] = {
        "*.pdf": PDFReader,
        "*.docx": DocxReader,
        "*.hwp": HWPReader,
        "*.pptx": PPTXReader,
        "*.ppt": PPTXReader,
        "*.pptm": PPTXReader,
        "*.gif": ImageReader,
        "*.jpeg": ImageReader,
        "*.jpg": ImageReader,
        "*.png": ImageReader,
        "*.webp": ImageReader,
        "*.ipynb": IPYNBReader,
        "*.epub": EpubReader,
        "*.md": MarkdownReader,
        "*.mbox": MboxReader,
        "*.csv": PandasCSVReader,
        "*.xls": PandasExcelReader,
        "*.xlsx": PandasExcelReader,
        "*.mp3": VideoAudioReader,
        "*.mp4": VideoAudioReader,
    }

    def __init__(self, input_dir: Optional[str] = None, input_files: Optional[List] = None,
                 exclude: Optional[List] = None, exclude_hidden: bool = True, recursive: bool = False,
                 encoding: str = "utf-8", filename_as_id: bool = False, required_exts: Optional[List[str]] = None,
                 file_extractor: Optional[Dict[str, Callable]] = None, fs: Optional[AbstractFileSystem] = None,
                 metadata_genf: Optional[Callable[[str], Dict]] = None, num_files_limit: Optional[int] = None,
                 return_trace: bool = False, metadatas: Optional[Dict] = None) -> None:
        super().__init__(return_trace=return_trace)

        if (not input_dir and not input_files) or (input_dir and input_files):
            raise ValueError("Must provide either `input_dir` or `input_files`.")

        self._fs = fs or get_default_fs()
        self._encoding = encoding

        self._exclude = exclude
        self._recursive = recursive
        self._exclude_hidden = exclude_hidden
        self._required_exts = required_exts
        self._num_files_limit = num_files_limit
        self._Path = Path if is_default_fs(self._fs) else PurePosixPath
        self._metadatas = metadatas

        if input_files:
            self._input_files = []
            for path in input_files:
                if not self._fs.isfile(path):
                    path = os.path.join(config['data_path'], path)
                    if not self._fs.isfile(path):
                        raise ValueError(f"File {path} does not exist.")
                input_file = self._Path(path)
                self._input_files.append(input_file)
        elif input_dir:
            if not self._fs.isdir(input_dir):
                raise ValueError(f"Directory {input_dir} does not exist.")
            self._input_dir = self._Path(input_dir)
            self._input_files = self._add_files(self._input_dir)

        self._file_extractor = file_extractor or {}

        self._metadata_genf = metadata_genf or _DefaultFileMetadataFunc(self._fs)
        if filename_as_id: LOG.warning('Argument `filename_as_id` for DataReader is no longer used')

    def _add_files(self, input_dir: Path) -> List[Path]:  # noqa: C901
        all_files = set()
        rejected_files = set()
        rejected_dirs = set()

        if self._exclude is not None:
            for excluded_pattern in self._exclude:
                if self._recursive:
                    excluded_glob = self._Path(input_dir) / self._Path("**") / excluded_pattern
                else:
                    excluded_glob = self._Path(input_dir) / excluded_pattern
                for file in self._fs.glob(str(excluded_glob)):
                    if self._fs.isdir(file):
                        rejected_dirs.add(self._Path(file))
                    else:
                        rejected_files.add(self._Path(file))

        file_refs: List[str] = []
        if self._recursive:
            file_refs = self._fs.glob(str(input_dir) + "/**/*")
        else:
            file_refs = self._fs.glob(str(input_dir) + "/*")

        for ref in file_refs:
            ref = self._Path(ref)
            is_dir = self._fs.isdir(ref)
            skip_hidden = self._exclude_hidden and self._is_hidden(ref)
            skip_bad_exts = (self._required_exts is not None and ref.suffix not in self._required_exts)
            skip_excluded = ref in rejected_files
            if not skip_excluded:
                if is_dir:
                    ref_parent_dir = ref
                else:
                    ref_parent_dir = self._fs._parent(ref)
                for rejected_dir in rejected_dirs:
                    if str(ref_parent_dir).startswith(str(rejected_dir)):
                        skip_excluded = True
                        LOG.warning(f"Skipping {ref} because it in parent dir "
                                    f"{ref_parent_dir} which is in {rejected_dir}.")
                        break

            if is_dir or skip_hidden or skip_bad_exts or skip_excluded:
                continue
            else:
                all_files.add(ref)

        new_input_files = sorted(all_files)

        if len(new_input_files) == 0:
            raise ValueError(f"No files found in {input_dir}.")
        if self._num_files_limit is not None and self._num_files_limit > 0:
            new_input_files = new_input_files[0: self._num_files_limit]

        LOG.debug(f"[SimpleDirectoryReader] Total files add: {len(new_input_files)}")

        LOG.info(f"input_files: {new_input_files}")
        return new_input_files

    def _is_hidden(self, path: Path) -> bool:
        return any(part.startswith(".") and part not in [".", ".."] for part in path.parts)

    def _exclude_metadata(self, documents: List[DocNode]) -> List[DocNode]:
        for doc in documents:
            doc._excluded_embed_metadata_keys.extend(
                ["file_name", "file_type", "file_size", "creation_date",
                 "last_modified_date", "last_accessed_date", "lazyllm_store_num"])
            doc._excluded_llm_metadata_keys.extend(
                ["file_name", "file_type", "file_size", "creation_date",
                 "last_modified_date", "last_accessed_date", "lazyllm_store_num"])
        return documents

    @staticmethod
    def load_file(input_file: Path, metadata_genf: Callable[[str], Dict], file_extractor: Dict[str, Callable],
                  encoding: str = "utf-8", pathm: PurePath = Path, fs: Optional[AbstractFileSystem] = None,
                  metadata: Optional[Dict] = None) -> List[DocNode]:
        # metadata priority: user > reader > metadata_genf
        user_metadata: Dict = metadata or {}
        metadata_generated: Dict = metadata_genf(str(input_file)) if metadata_genf else {}
        documents: List[DocNode] = []

        filename_lower = str(input_file).lower()

        for pattern, extractor in file_extractor.items():
            pt_lower = str(pathm(pattern)).lower()
            match_pattern = pt_lower if pt_lower.endswith("*") else os.path.join(str(pathm.cwd()).lower(), pt_lower)
            if pt_lower.startswith("*"):
                match_pattern = pt_lower
            else:
                base = str(pathm.cwd()).lower()
                match_pattern = os.path.join(base, pt_lower)

            if fnmatch.fnmatch(filename_lower, match_pattern):
                reader = extractor() if isinstance(extractor, type) else extractor
                kwargs = {'fs': fs} if fs and not is_default_fs(fs) else {}
                docs = reader(input_file, **kwargs)
                if isinstance(docs, DocNode): docs = [docs]
                for doc in docs:
                    metadata = metadata_generated.copy()
                    metadata.update(doc._global_metadata or {})
                    metadata.update(user_metadata)
                    doc._global_metadata = metadata

                if config['rag_filename_as_id']:
                    for i, doc in enumerate(docs):
                        doc._uid = f"{input_file!s}_index_{i}"
                documents.extend(docs)
                break
        else:
            if not config['use_fallback_reader']:
                LOG.warning(f'no pattern found for {input_file}! '
                            'If you want fallback to default Reader, set `LAZYLLM_USE_FALLBACK_READER=True`.')
                return documents
            fs = fs or get_default_fs()
            with fs.open(input_file, encoding=encoding) as f:
                try:
                    data = f.read().decode(encoding)
                    doc = DocNode(text=data, global_metadata=user_metadata)
                    documents.append(doc)
                except Exception:
                    LOG.error(f'no pattern found for {input_file} and it is not utf-8, skip it!')
        return documents

    def _load_data(self, show_progress: bool = False, num_workers: Optional[int] = None,
                   fs: Optional[AbstractFileSystem] = None) -> List[DocNode]:
        documents = []

        fs = fs or self._fs
        process_file = self._input_files
        file_readers = self._file_extractor.copy()
        for key, func in self.default_file_readers.items():
            if key not in file_readers: file_readers[key] = func

        if num_workers and num_workers >= 1:
            if num_workers > multiprocessing.cpu_count():
                LOG.warning("Specified num_workers exceed number of CPUs in the system. "
                            "Setting `num_workers` down to the maximum CPU count.")
            with multiprocessing.get_context("spawn").Pool(num_workers) as p:
                results = p.starmap(SimpleDirectoryReader.load_file,
                                    zip(process_file, repeat(self._metadata_genf), repeat(file_readers),
                                        repeat(self._encoding), repeat(self._Path),
                                        repeat(self._fs), self._metadatas or repeat(None)))
                documents = reduce(lambda x, y: x + y, results)
        else:
            if show_progress:
                process_file = tqdm(self._input_files, desc="Loading files", unit="file")
            for input_file, metadata in zip(process_file, self._metadatas or repeat(None)):
                documents.extend(
                    SimpleDirectoryReader.load_file(
                        input_file=input_file, metadata_genf=self._metadata_genf, file_extractor=file_readers,
                        encoding=self._encoding, pathm=self._Path, fs=self._fs, metadata=metadata))

        return self._exclude_metadata(documents)

    def forward(self, *args, **kwargs) -> List[DocNode]:
        return self._load_data(*args, **kwargs)

lazyllm.tools.rag.dataReader.FileReader

Bases: object

文件内容读取器,主要功能是将多种格式的输入文件转换为拼接后的纯文本内容。 Args: input_files (Optional[List]):直接指定的文件列表。

Examples:

>>> import lazyllm
>>> from lazyllm.tools.dataReader import FileReader
>>> reader = FileReader()
>>> content = reader("yourpath/")
Source code in lazyllm/tools/rag/dataReader.py
class FileReader(object):
    """
文件内容读取器,主要功能是将多种格式的输入文件转换为拼接后的纯文本内容。
Args:
    input_files (Optional[List]):直接指定的文件列表。


Examples:

    >>> import lazyllm
    >>> from lazyllm.tools.dataReader import FileReader
    >>> reader = FileReader()
    >>> content = reader("yourpath/") 
    """

    def __call__(self, input_files):
        file_list = _lazyllm_get_file_list(input_files)
        if isinstance(file_list, str) and file_list is not None:
            file_list = [file_list]
        if len(file_list) == 0:
            return []
        nodes = SimpleDirectoryReader(input_files=file_list)._load_data()
        txt = [node.get_text() for node in nodes]
        return "\n".join(txt)

lazyllm.tools.rag.web.DocWebModule

Bases: ModuleBase

文档Web界面模块,继承自ModuleBase,提供基于Web的文档管理交互界面。

Parameters:

  • doc_server (ServerModule) –

    文档服务模块实例,提供后端API支持

  • title (str, default: '文档管理演示终端' ) –

    界面标题,默认为"文档管理演示终端"

  • port (int / range / list, default: None ) –

    服务端口号或端口范围,默认为20800-20999

  • history (list, default: None ) –

    初始聊天历史记录,默认为空列表

  • text_mode (Mode, default: None ) –

    文本处理模式,默认为Mode.Dynamic(动态模式)

  • trace_mode (Mode, default: None ) –

    追踪模式,默认为Mode.Refresh(刷新模式)

类属性

Mode: 模式枚举类,包含: - Dynamic: 动态模式 - Refresh: 刷新模式 - Appendix: 附录模式

注意事项
  • 需要配合有效的doc_server实例使用
  • 端口冲突时会自动尝试范围内其他端口
  • 服务停止后会释放相关资源

Examples:

>>> import lazyllm
>>> from lazyllm.tools.rag.web import DocWebModule
>>> from lazyllm import
>>> doc_server = ServerModule(url="your_url")
>>> doc_web = DocWebModule(
>>>   doc_server=doc_server,
>>>   title="文档管理演示终端",
>>>   port=range(20800, 20805)  # 自动寻找可用端口)
>>> deploy_task = doc_web._get_deploy_tasks()
>>> deploy_task()  
>>> print(doc_web.url)
>>> doc_web.stop()
Source code in lazyllm/tools/rag/web.py
class DocWebModule(ModuleBase):
    """文档Web界面模块,继承自ModuleBase,提供基于Web的文档管理交互界面。

Args:
    doc_server (ServerModule): 文档服务模块实例,提供后端API支持
    title (str): 界面标题,默认为"文档管理演示终端"
    port (int/range/list): 服务端口号或端口范围,默认为20800-20999
    history (list): 初始聊天历史记录,默认为空列表
    text_mode (Mode): 文本处理模式,默认为Mode.Dynamic(动态模式)
    trace_mode (Mode): 追踪模式,默认为Mode.Refresh(刷新模式)

类属性:
    Mode: 模式枚举类,包含:
        - Dynamic: 动态模式
        - Refresh: 刷新模式
        - Appendix: 附录模式

注意事项:
    - 需要配合有效的doc_server实例使用
    - 端口冲突时会自动尝试范围内其他端口
    - 服务停止后会释放相关资源


Examples:
    >>> import lazyllm
    >>> from lazyllm.tools.rag.web import DocWebModule
    >>> from lazyllm import
    >>> doc_server = ServerModule(url="your_url")
    >>> doc_web = DocWebModule(
    >>>   doc_server=doc_server,
    >>>   title="文档管理演示终端",
    >>>   port=range(20800, 20805)  # 自动寻找可用端口)
    >>> deploy_task = doc_web._get_deploy_tasks()
    >>> deploy_task()  
    >>> print(doc_web.url)
    >>> doc_web.stop()
    """
    class Mode:
        Dynamic = 0
        Refresh = 1
        Appendix = 2

    def __init__(self, doc_server: ServerModule, title="文档管理演示终端", port=None,
                 history=None, text_mode=None, trace_mode=None) -> None:
        super().__init__()
        self.title = title
        self.port = port or range(20800, 20999)
        self.history = history or []
        self.trace_mode = trace_mode if trace_mode else DocWebModule.Mode.Refresh
        self.text_mode = text_mode if text_mode else DocWebModule.Mode.Dynamic
        self.doc_server = doc_server
        self._deploy_flag = lazyllm.once_flag()
        self.api_url = ""
        self.url = ""

    def _prepare(self, query, chat_history):
        if chat_history is None:
            chat_history = []
        return "", chat_history + [[query, None]]

    def _clear_history(self):
        return [], "", ""

    def _work(self):
        if isinstance(self.port, (range, tuple, list)):
            port = self._find_can_use_network_port()
        else:
            port = self.port
            assert self._verify_port_access(port), f"port {port} is occupied"

        self.api_url = self.doc_server._url.rsplit("/", 1)[0]
        self.web_ui = WebUi(self.api_url)
        self.demo = self.web_ui.create_ui()
        self.url = f'http://127.0.0.1:{port}'
        self.broadcast_url = f'http://0.0.0.0:{port}'

        self.demo.queue().launch(server_name="0.0.0.0", server_port=port, prevent_thread_lock=True)
        LOG.success('LazyLLM docwebmodule launched successfully: Running on: '
                    f'{self.broadcast_url}, local URL: {self.url}')

    def _get_deploy_tasks(self):
        return Pipeline(self._work)

    def _get_post_process_tasks(self):
        return Pipeline(self._print_url)

    def wait(self):
        """阻塞当前线程以保持Web界面运行,直到手动停止。

"""
        self.demo.block_thread()

    def stop(self):
        """停止Web界面服务并释放相关资源。

"""
        if self.demo:
            self.demo.close()
            del self.demo
            self.demo, self.url = None, ''

    def _find_can_use_network_port(self):
        for port in self.port:
            if self._verify_port_access(port):
                return port
        raise RuntimeError(
            f"The ports in the range {self.port} are all occupied. "
            "Please change the port range or release the relevant ports."
        )

    def _print_url(self):
        lazyllm.LOG.success(f"LazyLLM DocWebModule launched successfully: Running on local URL: {self.url}")

    def _verify_port_access(self, port):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            result = s.connect_ex(("127.0.0.1", port))
            return result != 0

    def __repr__(self):
        return lazyllm.make_repr("Module", "DocWebModule")

stop()

停止Web界面服务并释放相关资源。

Source code in lazyllm/tools/rag/web.py
    def stop(self):
        """停止Web界面服务并释放相关资源。

"""
        if self.demo:
            self.demo.close()
            del self.demo
            self.demo, self.url = None, ''

wait()

阻塞当前线程以保持Web界面运行,直到手动停止。

Source code in lazyllm/tools/rag/web.py
    def wait(self):
        """阻塞当前线程以保持Web界面运行,直到手动停止。

"""
        self.demo.block_thread()

lazyllm.tools.WebModule

Bases: ModuleBase

WebModule是LazyLLM为开发者提供的基于Web的交互界面。在初始化并启动一个WebModule之后,开发者可以从页面上看到WebModule背后的模块结构,并将Chatbot组件的输入传输给自己开发的模块进行处理。 模块返回的结果和日志会直接显示在网页的“处理日志”和Chatbot组件上。除此之外,WebModule支持在网页上动态加入Checkbox或Text组件用于向模块发送额外的参数。 WebModule页面还提供“使用上下文”,“流式输出”和“追加输出”的Checkbox,可以用来改变页面和后台模块的交互方式。

WebModule.init_web(component_descs) -> gradio.Blocks 使用gradio库生成演示web页面,初始化session相关数据以便在不同的页面保存各自的对话和日志,然后使用传入的component_descs参数为页面动态添加Checkbox和Text组件,最后设置页面上的按钮和文本框的相应函数 之后返回整个页面。WebModule的__init__函数调用此方法生成页面。

Parameters:

  • component_descs (list) –

    用于动态向页面添加组件的列表。列表中的每个元素也是一个列表,其中包含5个元素,分别是组件对应的模块ID,模块名,组件名,组件类型(目前仅支持Checkbox和Text),组件默认值。

Examples:

>>> import lazyllm
>>> def func2(in_str, do_sample=True, temperature=0.0, *args, **kwargs):
...     return f"func2:{in_str}|do_sample:{str(do_sample)}|temp:{temperature}"
...
>>> m1=lazyllm.ActionModule(func2)
>>> m1.name="Module1"
>>> w = lazyllm.WebModule(m1, port=[20570, 20571, 20572], components={
...         m1:[('do_sample', 'Checkbox', True), ('temperature', 'Text', 0.1)]},
...                       text_mode=lazyllm.tools.WebModule.Mode.Refresh)
>>> w.start()
193703: 2024-06-07 10:26:00 lazyllm SUCCESS: ...
Source code in lazyllm/tools/webpages/webmodule.py
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
class WebModule(ModuleBase):
    """WebModule是LazyLLM为开发者提供的基于Web的交互界面。在初始化并启动一个WebModule之后,开发者可以从页面上看到WebModule背后的模块结构,并将Chatbot组件的输入传输给自己开发的模块进行处理。
模块返回的结果和日志会直接显示在网页的“处理日志”和Chatbot组件上。除此之外,WebModule支持在网页上动态加入Checkbox或Text组件用于向模块发送额外的参数。
WebModule页面还提供“使用上下文”,“流式输出”和“追加输出”的Checkbox,可以用来改变页面和后台模块的交互方式。

<span style="font-size: 20px;">&ensp;**`WebModule.init_web(component_descs) -> gradio.Blocks`**</span>
使用gradio库生成演示web页面,初始化session相关数据以便在不同的页面保存各自的对话和日志,然后使用传入的component_descs参数为页面动态添加Checkbox和Text组件,最后设置页面上的按钮和文本框的相应函数
之后返回整个页面。WebModule的__init__函数调用此方法生成页面。

Args:
    component_descs (list): 用于动态向页面添加组件的列表。列表中的每个元素也是一个列表,其中包含5个元素,分别是组件对应的模块ID,模块名,组件名,组件类型(目前仅支持Checkbox和Text),组件默认值。


Examples:
    >>> import lazyllm
    >>> def func2(in_str, do_sample=True, temperature=0.0, *args, **kwargs):
    ...     return f"func2:{in_str}|do_sample:{str(do_sample)}|temp:{temperature}"
    ...
    >>> m1=lazyllm.ActionModule(func2)
    >>> m1.name="Module1"
    >>> w = lazyllm.WebModule(m1, port=[20570, 20571, 20572], components={
    ...         m1:[('do_sample', 'Checkbox', True), ('temperature', 'Text', 0.1)]},
    ...                       text_mode=lazyllm.tools.WebModule.Mode.Refresh)
    >>> w.start()
    193703: 2024-06-07 10:26:00 lazyllm SUCCESS: ...
    """
    class Mode:
        Dynamic = 0
        Refresh = 1
        Appendix = 2

    def __init__(self, m: Any, *, components: Dict[Any, Any] = dict(), title: str = '对话演示终端',  # noqa B008
                 port: Optional[Union[int, range, tuple, list]] = None, history: List[Any] = [],  # noqa B006
                 text_mode: Optional[Mode] = None, trace_mode: Optional[Mode] = None, audio: bool = False,
                 stream: bool = False, files_target: Optional[Union[Any, List[Any]]] = None,
                 static_paths: Optional[Union[str, Path, List[Union[str, Path]]]] = None,
                 encode_files: bool = False, share: bool = False) -> None:
        super().__init__()
        # Set the static directory of gradio so that gradio can access local resources in the directory
        if isinstance(static_paths, (str, Path)):
            self._static_paths = [static_paths]
        elif isinstance(static_paths, list) and all(isinstance(p, (str, Path)) for p in static_paths):
            self._static_paths = static_paths
        elif static_paths is None:
            self._static_paths = []
        else:
            raise ValueError(f"static_paths only supported str, path or list types. Not supported {static_paths}")
        self.m = lazyllm.ActionModule(m) if isinstance(m, lazyllm.FlowBase) else m
        self.pool = lazyllm.ThreadPoolExecutor(max_workers=50)
        self.title = title
        self.port = port or range(20500, 20799)
        components = sum([[([k._module_id, k._module_name] + list(v)) for v in vs]
                         for k, vs in components.items()], [])
        self.ckeys = [[c[0], c[2]] for c in components]
        if isinstance(m, (OnlineChatModule, TrainableModule)) and not history:
            history = [m]
        self.history = [h._module_id for h in history]
        if trace_mode:
            LOG.warn('trace_mode is deprecated')
        self.text_mode = text_mode if text_mode else WebModule.Mode.Dynamic
        self.cach_path = self._set_up_caching()
        self.audio = audio
        self.stream = stream
        self.files_target = files_target if isinstance(files_target, list) or files_target is None else [files_target]
        self.encode_files = encode_files
        self.share = share
        self.demo = self.init_web(components)
        self.url = None
        signal.signal(signal.SIGINT, self._signal_handler)
        signal.signal(signal.SIGTERM, self._signal_handler)

    def _get_all_file_submodule(self):
        if self.files_target: return
        self.files_target = []
        self.for_each(
            lambda x: getattr(x, 'template_message', None),
            lambda x: self.files_target.append(x)
        )

    def _signal_handler(self, signum, frame):
        LOG.info(f"Signal {signum} received, terminating subprocess.")
        atexit._run_exitfuncs()
        sys.exit(0)

    def _set_up_caching(self):
        if 'GRADIO_TEMP_DIR' in os.environ:
            cach_path = os.environ['GRADIO_TEMP_DIR']
        else:
            cach_path = os.path.join(lazyllm.config['temp_dir'], 'gradio_cach')
            os.environ['GRADIO_TEMP_DIR'] = cach_path
        if not os.path.exists(cach_path):
            os.makedirs(cach_path)
        return cach_path

    def init_web(self, component_descs):
        """初始化 Web UI 页面。
该方法使用 Gradio 构建对话界面,并将组件绑定到事件,支持会话选择、流式输出、上下文控制、多模态输入等功能。该方法返回构建完成的 Gradio Blocks 对象。
Args:
    component_descs (List[Tuple]): 组件描述列表,每项为五元组 (module, group_name, name, component_type, value),
        例如:('MyModule', 'GroupA', 'use_cache', 'Checkbox', True)。
Returns:
    gr.Blocks: 构建好的 Gradio 页面对象,可用于 launch 启动 Web 服务。
"""
        gr.set_static_paths(self._static_paths)
        with gr.Blocks(css=css, title=self.title, analytics_enabled=False) as demo:
            sess_data = gr.State(value={
                'sess_titles': [''],
                'sess_logs': {},
                'sess_history': {},
                'sess_num': 1,
                'curr_sess': '',
                'frozen_query': '',
            })
            with gr.Row():
                with gr.Column(scale=3):
                    with gr.Row():
                        with lazyllm.config.temp('repr_show_child', True):
                            gr.Textbox(elem_id='module', interactive=False, show_label=True,
                                       label="模型结构", value=repr(self.m))
                    with gr.Row():
                        chat_use_context = gr.Checkbox(interactive=True, value=False, label="使用上下文")
                    with gr.Row():
                        stream_output = gr.Checkbox(interactive=self.stream, value=self.stream, label="流式输出")
                        text_mode = gr.Checkbox(interactive=(self.text_mode == WebModule.Mode.Dynamic),
                                                value=(self.text_mode != WebModule.Mode.Refresh), label="追加输出")
                    components = []
                    for _, gname, name, ctype, value in component_descs:
                        if ctype in ('Checkbox', 'Text'):
                            components.append(getattr(gr, ctype)(interactive=True, value=value, label=f'{gname}.{name}'))
                        elif ctype == 'Dropdown':
                            components.append(getattr(gr, ctype)(interactive=True, choices=value,
                                                                 label=f'{gname}.{name}'))
                        else:
                            raise KeyError(f'invalid component type: {ctype}')
                    with gr.Row():
                        dbg_msg = gr.Textbox(show_label=True, label='处理日志',
                                             elem_id='logging', interactive=False, max_lines=10)
                    clear_btn = gr.Button(value="🗑️  Clear history", interactive=True)
                with gr.Column(scale=6):
                    with gr.Row():
                        add_sess_btn = gr.Button("添加新会话")
                        sess_drpdn = gr.Dropdown(choices=sess_data.value['sess_titles'], label="选择会话:", value='')
                        del_sess_btn = gr.Button("删除当前会话")
                    chatbot = gr.Chatbot(height=700)
                    query_box = gr.MultimodalTextbox(show_label=False, placeholder='输入内容并回车!!!', interactive=True)
                    recordor = gr.Audio(sources=["microphone"], type="filepath", visible=self.audio)

            query_box.submit(self._init_session, [query_box, sess_data, recordor],
                                                 [sess_drpdn, chatbot, dbg_msg, sess_data, recordor], queue=True
                ).then(lambda: gr.update(interactive=False), None, query_box, queue=False
                ).then(lambda: gr.update(interactive=False), None, add_sess_btn, queue=False
                ).then(lambda: gr.update(interactive=False), None, sess_drpdn, queue=False
                ).then(lambda: gr.update(interactive=False), None, del_sess_btn, queue=False
                ).then(self._prepare, [query_box, chatbot, sess_data], [query_box, chatbot], queue=True
                ).then(self._respond_stream, [chat_use_context, chatbot, stream_output, text_mode] + components,
                                             [chatbot, dbg_msg], queue=chatbot
                ).then(lambda: gr.update(interactive=True), None, query_box, queue=False
                ).then(lambda: gr.update(interactive=True), None, add_sess_btn, queue=False
                ).then(lambda: gr.update(interactive=True), None, sess_drpdn, queue=False
                ).then(lambda: gr.update(interactive=True), None, del_sess_btn, queue=False)
            clear_btn.click(self._clear_history, [sess_data], outputs=[chatbot, query_box, dbg_msg, sess_data])

            sess_drpdn.change(self._change_session, [sess_drpdn, chatbot, dbg_msg, sess_data],
                                                    [sess_drpdn, chatbot, query_box, dbg_msg, sess_data])
            add_sess_btn.click(self._add_session, [chatbot, dbg_msg, sess_data],
                                                  [sess_drpdn, chatbot, query_box, dbg_msg, sess_data])
            del_sess_btn.click(self._delete_session, [sess_drpdn, sess_data],
                                                     [sess_drpdn, chatbot, query_box, dbg_msg, sess_data])
            recordor.change(self._sub_audio, recordor, query_box)
            return demo

    def _sub_audio(self, audio):
        if audio:
            return {'text': '', 'files': [audio]}
        else:
            return {}

    def _init_session(self, query, session, audio):
        audio = None
        session['frozen_query'] = query
        if session['curr_sess'] != '':  # remain unchanged.
            return gr.Dropdown(), gr.Chatbot(), gr.Textbox(), session, audio

        if "text" in query and query["text"] is not None:
            id_name = query['text']
        else:
            id_name = id(id_name)
        session['curr_sess'] = f"({session['sess_num']})  {id_name}"
        session['sess_num'] += 1
        session['sess_titles'][0] = session['curr_sess']

        session['sess_logs'][session['curr_sess']] = []
        session['sess_history'][session['curr_sess']] = []
        return gr.update(choices=session['sess_titles'], value=session['curr_sess']), [], '', session, audio

    def _add_session(self, chat_history, log_history, session):
        if session['curr_sess'] == '':
            LOG.warning('Cannot create new session while current session is empty.')
            return gr.Dropdown(), gr.Chatbot(), {}, gr.Textbox(), session

        self._save_history(chat_history, log_history, session)

        session['curr_sess'] = ''
        session['sess_titles'].insert(0, session['curr_sess'])
        return gr.update(choices=session['sess_titles'], value=session['curr_sess']), [], {}, '', session

    def _save_history(self, chat_history, log_history, session):
        if session['curr_sess'] in session['sess_titles']:
            session['sess_history'][session['curr_sess']] = chat_history
            session['sess_logs'][session['curr_sess']] = log_history

    def _change_session(self, session_title, chat_history, log_history, session):
        if session['curr_sess'] == '':  # new session
            return gr.Dropdown(), [], {}, '', session

        if session_title not in session['sess_titles']:
            LOG.warning(f'{session_title} is not an existing session title.')
            return gr.Dropdown(), gr.Chatbot(), {}, gr.Textbox(), session

        self._save_history(chat_history, log_history, session)

        session['curr_sess'] = session_title
        return (gr.update(choices=session['sess_titles'], value=session['curr_sess']),
                session['sess_history'][session['curr_sess']], {},
                session['sess_logs'][session['curr_sess']], session)

    def _delete_session(self, session_title, session):
        if session_title not in session['sess_titles']:
            LOG.warning(f'session {session_title} does not exist.')
            return gr.Dropdown(), session
        session['sess_titles'].remove(session_title)

        if session_title != '':
            del session['sess_history'][session_title]
            del session['sess_logs'][session_title]
            session['curr_sess'] = session_title
        else:
            session['curr_sess'] = 'dummy session'
            # add_session and change_session cannot accept an uninitialized session.
            # Here we need to imitate removal of a real session so that
            # add_session and change_session could skip saving chat history.

        if len(session['sess_titles']) == 0:
            return self._add_session(None, None, session)
        else:
            return self._change_session(session['sess_titles'][0], None, {}, session)

    def _prepare(self, query, chat_history, session):
        if not query.get('text', '') and not query.get('files', []):
            query = session['frozen_query']
        if chat_history is None:
            chat_history = []
        for x in query["files"]:
            chat_history.append([[x,], None])
        if "text" in query and query["text"]:
            chat_history.append([query['text'], None])
        return {}, chat_history

    def _respond_stream(self, use_context, chat_history, stream_output, append_text, *args):  # noqa C901
        try:
            # TODO: move context to trainable module
            files = []
            chat_history[-1][1], log_history = '', []
            for file in chat_history[::-1]:
                if file[-1]: break  # not current chat
                if isinstance(file[0], (tuple, list)):
                    files.append(file[0][0])
                elif isinstance(file[0], str) and file[0].startswith('lazyllm_img::'):  # Just for pytest
                    files.append(file[0][13:])
            if isinstance(chat_history[-1][0], str):
                string = chat_history[-1][0]
            else:
                string = ''
            if self.files_target is None and not self.encode_files:
                self._get_all_file_submodule()
            if self.encode_files and files:
                string = encode_query_with_filepaths(string, files)
            if files and self.files_target:
                for module in self.files_target:
                    assert isinstance(module, ModuleBase)
                    if module._module_id in globals['lazyllm_files']:
                        globals['lazyllm_files'][module._module_id].extend(files)
                    else:
                        globals['lazyllm_files'][module._module_id] = files
                string += f' ## Get attachments: {os.path.basename(files[-1])}'
            elif self.files_target:
                for module in self.files_target:
                    assert isinstance(module, ModuleBase)
                    globals['lazyllm_files'][module._module_id] = []
            input = string
            history = chat_history[:-1] if use_context and len(chat_history) > 1 else list()

            for k, v in zip(self.ckeys, args):
                if k[0] not in globals['global_parameters']: globals['global_parameters'][k[0]] = dict()
                globals['global_parameters'][k[0]][k[1]] = v

            if use_context:
                for h in self.history:
                    if h not in globals['chat_history']: globals['chat_history'][h] = list()
                    globals['chat_history'][h] = history

            if FileSystemQueue().size() > 0: FileSystemQueue().clear()
            kw = dict(stream_output=stream_output) if isinstance(self.m, (TrainableModule, OnlineChatModule)) else {}
            func_future = self.pool.submit(self.m, input, **kw)
            while True:
                if value := FileSystemQueue().dequeue():
                    chat_history[-1][1] += ''.join(value) if append_text else ''.join(value)
                    if stream_output: yield chat_history, ''
                elif value := FileSystemQueue.get_instance('lazy_error').dequeue():
                    log_history.append(''.join(value))
                elif value := FileSystemQueue.get_instance('lazy_trace').dequeue():
                    log_history.append(''.join(value))
                elif func_future.done(): break
                time.sleep(0.01)
            result = func_future.result()
            if FileSystemQueue().size() > 0: FileSystemQueue().clear()

            def get_log_and_message(s):
                if isinstance(s, dict):
                    s = s.get("message", {}).get("content", "")
                else:
                    try:
                        r = decode_query_with_filepaths(s)
                        if isinstance(r, str):
                            r = json.loads(r)
                        if 'choices' in r:
                            if "type" not in r["choices"][0] or (
                                    "type" in r["choices"][0] and r["choices"][0]["type"] != "tool_calls"):
                                delta = r["choices"][0]["delta"]
                                if "content" in delta:
                                    s = delta["content"]
                                else:
                                    s = ""
                        elif isinstance(r, dict) and 'files' in r and 'query' in r:
                            return r['query'], ''.join(log_history), r['files'] if len(r['files']) > 0 else None
                        else:
                            s = s
                    except (ValueError, KeyError, TypeError):
                        s = s
                    except Exception as e:
                        LOG.error(f"Uncaptured error `{e}` when parsing `{s}`, please contact us if you see this.")
                return s, "".join(log_history), None

            def contains_markdown_image(text: str):
                pattern = r"!\[.*?\]\((.*?)\)"
                return bool(re.search(pattern, text))

            def extract_img_path(text: str):
                pattern = r"!\[.*?\]\((.*?)\)"
                urls = re.findall(pattern, text)
                return urls

            file_paths = None
            if isinstance(result, (str, dict)):
                result, log, file_paths = get_log_and_message(result)
            if file_paths:
                for i, file_path in enumerate(file_paths):
                    suffix = os.path.splitext(file_path)[-1].lower()
                    file = None
                    if suffix in PIL.Image.registered_extensions().keys():
                        file = gr.Image(file_path)
                    elif suffix in ('.mp3', '.wav'):
                        file = gr.Audio(file_path)
                    elif suffix in ('.mp4'):
                        file = gr.Video(file_path)
                    else:
                        LOG.error(f'Not supported typr: {suffix}, for file: {file}')
                    if i == 0:
                        chat_history[-1][1] = file
                    else:
                        chat_history.append([None, file])
                if result:
                    chat_history.append([None, result])
            else:
                assert isinstance(result, str), f'Result should only be str, but got {type(result)}'
                show_result = result
                if contains_markdown_image(show_result):
                    urls = extract_img_path(show_result)
                    for url in urls:
                        suffix = os.path.splitext(url)[-1].lower()
                        if suffix in PIL.Image.registered_extensions().keys() and os.path.exists(url):
                            show_result = show_result.replace(url, "file=" + url)
                if result:
                    count = (len(match.group(1)) if (match := re.search(r'(\n+)$', result)) else 0) + len(result) + 1
                    if not (result in chat_history[-1][1][-count:]):
                        chat_history[-1][1] += "\n\n" + show_result
                    elif show_result != result:
                        chat_history[-1][1] = chat_history[-1][1].replace(result, show_result)
        except requests.RequestException as e:
            chat_history = None
            log = str(e)
        except Exception as e:
            chat_history = None
            log = f'{str(e)}\n--- traceback ---\n{traceback.format_exc()}'
            LOG.error(log)
        globals['chat_history'].clear()
        yield chat_history, log

    def _clear_history(self, session):
        session['sess_history'][session['curr_sess']] = []
        session['sess_logs'][session['curr_sess']] = []
        return [], {}, '', session

    def _work(self):
        if isinstance(self.port, (range, tuple, list)):
            port = self._find_can_use_network_port()
        else:
            port = self.port
            assert self._verify_port_access(port), f'port {port} is occupied'

        self.url = f'http://127.0.0.1:{port}'
        self.broadcast_url = f'http://0.0.0.0:{port}'

        self.demo.queue().launch(server_name="0.0.0.0", server_port=port, prevent_thread_lock=True, share=self.share)
        LOG.success('LazyLLM webmodule launched successfully: Running on: '
                    f'{self.broadcast_url}, local URL: {self.url}')

    def _update(self, *, mode=None, recursive=True):
        super(__class__, self)._update(mode=mode, recursive=recursive)
        self._work()
        return self

    def wait(self):
        """阻塞主线程,等待 Web 页面关闭。
该方法会阻塞当前线程直到 Web 页面(Gradio demo)被关闭,适用于部署后阻止程序提前退出的场景。
"""
        self.demo.block_thread()

    def stop(self):
        """关闭 Web 页面并清理资源。
如果 Web 页面已初始化,则关闭 Gradio demo,释放资源并重置 `demo` 与 `url` 属性。
"""
        if self.demo:
            self.demo.close()
            del self.demo
            self.demo, self.url = None, ''

    @property
    def status(self):
        return 'running' if self.url else 'waiting' if self.url is None else 'Cancelled'

    def __repr__(self):
        return lazyllm.make_repr('Module', 'Web', name=self._module_name, subs=[repr(self.m)])

    def _find_can_use_network_port(self):
        for port in self.port:
            if self._verify_port_access(port):
                return port
        raise RuntimeError(
            f'The ports in the range {self.port} are all occupied. '
            'Please change the port range or release the relevant ports.'
        )

    def _verify_port_access(self, port):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            result = s.connect_ex(('127.0.0.1', port))
            return result != 0

init_web(component_descs)

初始化 Web UI 页面。 该方法使用 Gradio 构建对话界面,并将组件绑定到事件,支持会话选择、流式输出、上下文控制、多模态输入等功能。该方法返回构建完成的 Gradio Blocks 对象。 Args: component_descs (List[Tuple]): 组件描述列表,每项为五元组 (module, group_name, name, component_type, value), 例如:('MyModule', 'GroupA', 'use_cache', 'Checkbox', True)。 Returns: gr.Blocks: 构建好的 Gradio 页面对象,可用于 launch 启动 Web 服务。

Source code in lazyllm/tools/webpages/webmodule.py
    def init_web(self, component_descs):
        """初始化 Web UI 页面。
该方法使用 Gradio 构建对话界面,并将组件绑定到事件,支持会话选择、流式输出、上下文控制、多模态输入等功能。该方法返回构建完成的 Gradio Blocks 对象。
Args:
    component_descs (List[Tuple]): 组件描述列表,每项为五元组 (module, group_name, name, component_type, value),
        例如:('MyModule', 'GroupA', 'use_cache', 'Checkbox', True)。
Returns:
    gr.Blocks: 构建好的 Gradio 页面对象,可用于 launch 启动 Web 服务。
"""
        gr.set_static_paths(self._static_paths)
        with gr.Blocks(css=css, title=self.title, analytics_enabled=False) as demo:
            sess_data = gr.State(value={
                'sess_titles': [''],
                'sess_logs': {},
                'sess_history': {},
                'sess_num': 1,
                'curr_sess': '',
                'frozen_query': '',
            })
            with gr.Row():
                with gr.Column(scale=3):
                    with gr.Row():
                        with lazyllm.config.temp('repr_show_child', True):
                            gr.Textbox(elem_id='module', interactive=False, show_label=True,
                                       label="模型结构", value=repr(self.m))
                    with gr.Row():
                        chat_use_context = gr.Checkbox(interactive=True, value=False, label="使用上下文")
                    with gr.Row():
                        stream_output = gr.Checkbox(interactive=self.stream, value=self.stream, label="流式输出")
                        text_mode = gr.Checkbox(interactive=(self.text_mode == WebModule.Mode.Dynamic),
                                                value=(self.text_mode != WebModule.Mode.Refresh), label="追加输出")
                    components = []
                    for _, gname, name, ctype, value in component_descs:
                        if ctype in ('Checkbox', 'Text'):
                            components.append(getattr(gr, ctype)(interactive=True, value=value, label=f'{gname}.{name}'))
                        elif ctype == 'Dropdown':
                            components.append(getattr(gr, ctype)(interactive=True, choices=value,
                                                                 label=f'{gname}.{name}'))
                        else:
                            raise KeyError(f'invalid component type: {ctype}')
                    with gr.Row():
                        dbg_msg = gr.Textbox(show_label=True, label='处理日志',
                                             elem_id='logging', interactive=False, max_lines=10)
                    clear_btn = gr.Button(value="🗑️  Clear history", interactive=True)
                with gr.Column(scale=6):
                    with gr.Row():
                        add_sess_btn = gr.Button("添加新会话")
                        sess_drpdn = gr.Dropdown(choices=sess_data.value['sess_titles'], label="选择会话:", value='')
                        del_sess_btn = gr.Button("删除当前会话")
                    chatbot = gr.Chatbot(height=700)
                    query_box = gr.MultimodalTextbox(show_label=False, placeholder='输入内容并回车!!!', interactive=True)
                    recordor = gr.Audio(sources=["microphone"], type="filepath", visible=self.audio)

            query_box.submit(self._init_session, [query_box, sess_data, recordor],
                                                 [sess_drpdn, chatbot, dbg_msg, sess_data, recordor], queue=True
                ).then(lambda: gr.update(interactive=False), None, query_box, queue=False
                ).then(lambda: gr.update(interactive=False), None, add_sess_btn, queue=False
                ).then(lambda: gr.update(interactive=False), None, sess_drpdn, queue=False
                ).then(lambda: gr.update(interactive=False), None, del_sess_btn, queue=False
                ).then(self._prepare, [query_box, chatbot, sess_data], [query_box, chatbot], queue=True
                ).then(self._respond_stream, [chat_use_context, chatbot, stream_output, text_mode] + components,
                                             [chatbot, dbg_msg], queue=chatbot
                ).then(lambda: gr.update(interactive=True), None, query_box, queue=False
                ).then(lambda: gr.update(interactive=True), None, add_sess_btn, queue=False
                ).then(lambda: gr.update(interactive=True), None, sess_drpdn, queue=False
                ).then(lambda: gr.update(interactive=True), None, del_sess_btn, queue=False)
            clear_btn.click(self._clear_history, [sess_data], outputs=[chatbot, query_box, dbg_msg, sess_data])

            sess_drpdn.change(self._change_session, [sess_drpdn, chatbot, dbg_msg, sess_data],
                                                    [sess_drpdn, chatbot, query_box, dbg_msg, sess_data])
            add_sess_btn.click(self._add_session, [chatbot, dbg_msg, sess_data],
                                                  [sess_drpdn, chatbot, query_box, dbg_msg, sess_data])
            del_sess_btn.click(self._delete_session, [sess_drpdn, sess_data],
                                                     [sess_drpdn, chatbot, query_box, dbg_msg, sess_data])
            recordor.change(self._sub_audio, recordor, query_box)
            return demo

stop()

关闭 Web 页面并清理资源。 如果 Web 页面已初始化,则关闭 Gradio demo,释放资源并重置 demourl 属性。

Source code in lazyllm/tools/webpages/webmodule.py
    def stop(self):
        """关闭 Web 页面并清理资源。
如果 Web 页面已初始化,则关闭 Gradio demo,释放资源并重置 `demo` 与 `url` 属性。
"""
        if self.demo:
            self.demo.close()
            del self.demo
            self.demo, self.url = None, ''

wait()

阻塞主线程,等待 Web 页面关闭。 该方法会阻塞当前线程直到 Web 页面(Gradio demo)被关闭,适用于部署后阻止程序提前退出的场景。

Source code in lazyllm/tools/webpages/webmodule.py
    def wait(self):
        """阻塞主线程,等待 Web 页面关闭。
该方法会阻塞当前线程直到 Web 页面(Gradio demo)被关闭,适用于部署后阻止程序提前退出的场景。
"""
        self.demo.block_thread()

lazyllm.tools.CodeGenerator

Bases: ModuleBase

代码生成模块。

该模块基于用户提供的提示词生成代码,会根据提示内容自动选择中文或英文的系统提示词,并从输出中提取 Python 代码片段。

__init__(self, base_model, prompt="") 初始化代码生成器。

Parameters:

  • base_model (Union[str, TrainableModule, OnlineChatModuleBase]) –

    模型路径字符串,或已初始化的模型实例。

  • prompt (str, default: '' ) –

    用户自定义的代码生成提示词,可为中文或英文。

Examples:

>>> from lazyllm.components import CodeGenerator
>>> generator = CodeGenerator(base_model="deepseek-coder", prompt="写一个Python函数,计算斐波那契数列。")
>>> result = generator("请给出实现代码")
>>> print(result)
... def fibonacci(n):
...     if n <= 1:
...         return n
...     return fibonacci(n-1) + fibonacci(n-2)
Source code in lazyllm/tools/actors/code_generator.py
class CodeGenerator(ModuleBase):
    """代码生成模块。

该模块基于用户提供的提示词生成代码,会根据提示内容自动选择中文或英文的系统提示词,并从输出中提取 Python 代码片段。

`__init__(self, base_model, prompt="")`
初始化代码生成器。

Args:
    base_model (Union[str, TrainableModule, OnlineChatModuleBase]): 模型路径字符串,或已初始化的模型实例。
    prompt (str): 用户自定义的代码生成提示词,可为中文或英文。


Examples:
    >>> from lazyllm.components import CodeGenerator
    >>> generator = CodeGenerator(base_model="deepseek-coder", prompt="写一个Python函数,计算斐波那契数列。")
    >>> result = generator("请给出实现代码")
    >>> print(result)
    ... def fibonacci(n):
    ...     if n <= 1:
    ...         return n
    ...     return fibonacci(n-1) + fibonacci(n-2)
    """
    def __init__(
        self,
        base_model: Union[str, TrainableModule, OnlineChatModuleBase],
        prompt: str = "",
    ):
        super().__init__()
        self._prompt = self.choose_prompt(prompt).format(prompt=prompt)
        if isinstance(base_model, str):
            self._m = TrainableModule(base_model).start().prompt(self._prompt)
        else:
            self._m = base_model.share(self._prompt)

    def choose_prompt(self, prompt: str):
        """根据输入的提示文本内容选择合适的代码生成提示模板。  
如果提示中包含中文字符,则返回中文提示模板;否则返回英文提示模板。

Args:
    prompt (str): 输入的提示文本。

**Returns:**

- str: 选择的代码生成提示模板字符串。
"""
        # Use chinese prompt if intent elements have chinese character, otherwise use english version
        for ele in prompt:
            # chinese unicode range
            if "\u4e00" <= ele <= "\u9fff":
                return ch_code_generate_prompt
        return en_code_generate_prompt

    def forward(self, *args, **kw):
        res = self._m(*args, **kw)
        pattern = r"```python(.*?)\n```"
        matches = re.findall(pattern, res, re.DOTALL)
        if len(matches) > 0:
            return matches[0]
        return res

choose_prompt(prompt)

根据输入的提示文本内容选择合适的代码生成提示模板。
如果提示中包含中文字符,则返回中文提示模板;否则返回英文提示模板。

Parameters:

  • prompt (str) –

    输入的提示文本。

Returns:

  • str: 选择的代码生成提示模板字符串。
Source code in lazyllm/tools/actors/code_generator.py
    def choose_prompt(self, prompt: str):
        """根据输入的提示文本内容选择合适的代码生成提示模板。  
如果提示中包含中文字符,则返回中文提示模板;否则返回英文提示模板。

Args:
    prompt (str): 输入的提示文本。

**Returns:**

- str: 选择的代码生成提示模板字符串。
"""
        # Use chinese prompt if intent elements have chinese character, otherwise use english version
        for ele in prompt:
            # chinese unicode range
            if "\u4e00" <= ele <= "\u9fff":
                return ch_code_generate_prompt
        return en_code_generate_prompt

lazyllm.tools.ParameterExtractor

Bases: ModuleBase

参数提取模块。

该模块根据参数名称、类型、描述和是否必填,从文本中提取结构化参数,底层依赖语言模型实现。

__init__(self, base_model, param, type, description, require) 使用参数定义和模型初始化参数提取器。

Parameters:

  • base_model (Union[str, TrainableModule, OnlineChatModuleBase]) –

    用于参数提取的模型路径或模型实例。

  • param (list[str]) –

    需要提取的参数名称列表。

  • type (list[str]) –

    参数类型列表,如 "int"、"str"、"bool" 等。

  • description (list[str]) –

    每个参数的描述信息。

  • require (list[bool]) –

    每个参数是否为必填项的布尔列表。

Examples:

>>> from lazyllm.components import ParameterExtractor
>>> extractor = ParameterExtractor(
...     base_model="deepseek-chat",
...     param=["name", "age"],
...     type=["str", "int"],
...     description=["The user's name", "The user's age"],
...     require=[True, True]
... )
>>> result = extractor("My name is Alice and I am 25 years old.")
>>> print(result)
... ['Alice', 25]
Source code in lazyllm/tools/actors/parameter_extractor.py
class ParameterExtractor(ModuleBase):
    """参数提取模块。

该模块根据参数名称、类型、描述和是否必填,从文本中提取结构化参数,底层依赖语言模型实现。

`__init__(self, base_model, param, type, description, require)`
使用参数定义和模型初始化参数提取器。

Args:
    base_model (Union[str, TrainableModule, OnlineChatModuleBase]): 用于参数提取的模型路径或模型实例。
    param (list[str]): 需要提取的参数名称列表。
    type (list[str]): 参数类型列表,如 "int"、"str"、"bool" 等。
    description (list[str]): 每个参数的描述信息。
    require (list[bool]): 每个参数是否为必填项的布尔列表。


Examples:
    >>> from lazyllm.components import ParameterExtractor
    >>> extractor = ParameterExtractor(
    ...     base_model="deepseek-chat",
    ...     param=["name", "age"],
    ...     type=["str", "int"],
    ...     description=["The user's name", "The user's age"],
    ...     require=[True, True]
    ... )
    >>> result = extractor("My name is Alice and I am 25 years old.")
    >>> print(result)
    ... ['Alice', 25]
    """
    type_map = {
        int.__name__: int,
        str.__name__: str,
        float.__name__: float,
        bool.__name__: bool,
        list.__name__: list,
        dict.__name__: dict,
    }

    def __init__(
        self,
        base_model: Union[str, TrainableModule, OnlineChatModuleBase],
        param: list[str],
        type: list[str],
        description: list[str],
        require: list[bool],
    ):
        super().__init__()
        assert len(param) == len(type) == len(description) == len(require) > 0
        self._param_dict = {p: ParameterExtractor.type_map[t] for p, t in zip(param, type)}
        param_prompt = repr([dict(name=p, type=t, description=d, require=r)
                             for p, t, d, r in zip(param, type, description, require)])
        self._prompt = self.choose_prompt(param_prompt).format(prompt=param_prompt)
        if isinstance(base_model, str):
            self._m = TrainableModule(base_model).start().prompt(self._prompt)
        else:
            self._m = base_model.share(self._prompt)

    def choose_prompt(self, prompt: str):
        # Use chinese prompt if intent elements have chinese character, otherwise use english version
        for ele in prompt:
            # chinese unicode range
            if "\u4e00" <= ele <= "\u9fff":
                return ch_parameter_extractor_prompt
        return en_parameter_extractor_prompt

    def forward(self, *args, **kw):
        res = self._m(*args, **kw)
        pattern = r"```json(.*?)\n```"
        matches = re.findall(pattern, res, re.DOTALL)
        if len(matches) > 0:
            res = matches[0]
            res.strip()
            try:
                res = json.loads(res)
            except Exception:
                pass
        else:
            res = res.split("\n")
            for param in res:
                try:
                    res = json.loads(param)
                except Exception:
                    continue
                if isinstance(res, dict): break
        if isinstance(res, dict):
            ret = [res.get(param_name, None) for param_name in self._param_dict]
        else:
            ret = [None] * len(self._param_dict)
        ret = package(ret)
        return ret

lazyllm.tools.QustionRewrite

Bases: ModuleBase

问题改写模块。

该模块使用语言模型对用户输入的问题进行改写,可根据输出格式选择返回字符串或列表。

__init__(self, base_model, rewrite_prompt="", formatter="str") 使用提示词和模型初始化问题改写模块。

Parameters:

  • base_model (Union[str, TrainableModule, OnlineChatModuleBase]) –

    问题改写所使用的模型路径或已初始化模型。

  • rewrite_prompt (str, default: '' ) –

    用户自定义的改写提示词。

  • formatter (str, default: 'str' ) –

    输出格式,可选 "str"(字符串)或 "list"(按行分割的列表)。

Examples:

>>> from lazyllm.components import QustionRewrite
>>> rewriter = QustionRewrite(base_model="chatglm", rewrite_prompt="请将问题改写为更适合检索的形式", formatter="list")
>>> result = rewriter("中国的最高山峰是什么?")
>>> print(result)
... ['中国的最高山峰是哪一座?', '中国海拔最高的山是什么?']
Source code in lazyllm/tools/actors/qustion_rewrite.py
class QustionRewrite(ModuleBase):
    """问题改写模块。

该模块使用语言模型对用户输入的问题进行改写,可根据输出格式选择返回字符串或列表。

`__init__(self, base_model, rewrite_prompt="", formatter="str")`
使用提示词和模型初始化问题改写模块。

Args:
    base_model (Union[str, TrainableModule, OnlineChatModuleBase]): 问题改写所使用的模型路径或已初始化模型。
    rewrite_prompt (str): 用户自定义的改写提示词。
    formatter (str): 输出格式,可选 "str"(字符串)或 "list"(按行分割的列表)。


Examples:
    >>> from lazyllm.components import QustionRewrite
    >>> rewriter = QustionRewrite(base_model="chatglm", rewrite_prompt="请将问题改写为更适合检索的形式", formatter="list")
    >>> result = rewriter("中国的最高山峰是什么?")
    >>> print(result)
    ... ['中国的最高山峰是哪一座?', '中国海拔最高的山是什么?']
    """
    def __init__(
        self,
        base_model: Union[str, TrainableModule, OnlineChatModuleBase],
        rewrite_prompt: str = "",
        formatter: str = "str",
    ):
        super().__init__()
        self._prompt = self.choose_prompt(rewrite_prompt).format(prompt=rewrite_prompt)
        if isinstance(base_model, str):
            self._m = TrainableModule(base_model).start().prompt(self._prompt)
        else:
            self._m = base_model.share(self._prompt)
        self.formatter = formatter

    def choose_prompt(self, prompt: str):
        """
根据输入提示的语言选择合适的提示模板。

此方法分析输入提示字符串并确定使用中文还是英文提示模板。它检查提示字符串中的每个字符,如果任何字符落在中文字符Unicode范围内(\u4e00-\u9fff),则返回中文提示模板;否则返回英文提示模板。

Args:
    prompt (str): 要分析语言检测的输入提示字符串。

Returns:
    str: 选定的提示模板字符串(中文或英文版本)。


Examples:

    >>> from lazyllm.tools.actors.qustion_rewrite import QustionRewrite

    # Example 1: English prompt (no Chinese characters)
    >>> rewriter = QustionRewrite("gpt-3.5-turbo")
    >>> prompt_template = rewriter.choose_prompt("How to implement machine learning?")
    >>> print("Template contains Chinese:", "中文" in prompt_template)
    Template contains Chinese: False

    # Example 2: Chinese prompt (contains Chinese characters)
    >>> prompt_template = rewriter.choose_prompt("如何实现机器学习?")
    >>> print("Template contains Chinese:", "中文" in prompt_template)
    Template contains Chinese: True

    # Example 3: Mixed language prompt (contains Chinese characters)
    >>> prompt_template = rewriter.choose_prompt("What is 机器学习?")
    >>> print("Template contains Chinese:", "中文" in prompt_template)
    Template contains Chinese: True
    """
        # Use chinese prompt if intent elements have chinese character, otherwise use english version
        for ele in prompt:
            # chinese unicode range
            if "\u4e00" <= ele <= "\u9fff":
                return ch_qustion_rewrite_prompt
        return en_qustion_rewrite_prompt

    def forward(self, *args, **kw):
        res = self._m(*args, **kw)
        if self.formatter == "list":
            return list(filter(None, res.split('\n')))
        else:
            return res

choose_prompt(prompt)

根据输入提示的语言选择合适的提示模板。

此方法分析输入提示字符串并确定使用中文还是英文提示模板。它检查提示字符串中的每个字符,如果任何字符落在中文字符Unicode范围内(一-鿿),则返回中文提示模板;否则返回英文提示模板。

Parameters:

  • prompt (str) –

    要分析语言检测的输入提示字符串。

Returns:

  • str

    选定的提示模板字符串(中文或英文版本)。

Examples:

>>> from lazyllm.tools.actors.qustion_rewrite import QustionRewrite

# Example 1: English prompt (no Chinese characters)
>>> rewriter = QustionRewrite("gpt-3.5-turbo")
>>> prompt_template = rewriter.choose_prompt("How to implement machine learning?")
>>> print("Template contains Chinese:", "中文" in prompt_template)
Template contains Chinese: False

# Example 2: Chinese prompt (contains Chinese characters)
>>> prompt_template = rewriter.choose_prompt("如何实现机器学习?")
>>> print("Template contains Chinese:", "中文" in prompt_template)
Template contains Chinese: True

# Example 3: Mixed language prompt (contains Chinese characters)
>>> prompt_template = rewriter.choose_prompt("What is 机器学习?")
>>> print("Template contains Chinese:", "中文" in prompt_template)
Template contains Chinese: True
Source code in lazyllm/tools/actors/qustion_rewrite.py
    def choose_prompt(self, prompt: str):
        """
根据输入提示的语言选择合适的提示模板。

此方法分析输入提示字符串并确定使用中文还是英文提示模板。它检查提示字符串中的每个字符,如果任何字符落在中文字符Unicode范围内(\u4e00-\u9fff),则返回中文提示模板;否则返回英文提示模板。

Args:
    prompt (str): 要分析语言检测的输入提示字符串。

Returns:
    str: 选定的提示模板字符串(中文或英文版本)。


Examples:

    >>> from lazyllm.tools.actors.qustion_rewrite import QustionRewrite

    # Example 1: English prompt (no Chinese characters)
    >>> rewriter = QustionRewrite("gpt-3.5-turbo")
    >>> prompt_template = rewriter.choose_prompt("How to implement machine learning?")
    >>> print("Template contains Chinese:", "中文" in prompt_template)
    Template contains Chinese: False

    # Example 2: Chinese prompt (contains Chinese characters)
    >>> prompt_template = rewriter.choose_prompt("如何实现机器学习?")
    >>> print("Template contains Chinese:", "中文" in prompt_template)
    Template contains Chinese: True

    # Example 3: Mixed language prompt (contains Chinese characters)
    >>> prompt_template = rewriter.choose_prompt("What is 机器学习?")
    >>> print("Template contains Chinese:", "中文" in prompt_template)
    Template contains Chinese: True
    """
        # Use chinese prompt if intent elements have chinese character, otherwise use english version
        for ele in prompt:
            # chinese unicode range
            if "\u4e00" <= ele <= "\u9fff":
                return ch_qustion_rewrite_prompt
        return en_qustion_rewrite_prompt

lazyllm.tools.agent.toolsManager.ToolManager

Bases: ModuleBase

ToolManager是一个工具管理类,用于提供工具信息和工具调用给function call。

此管理类构造时需要传入工具名字符串列表。此处工具名可以是LazyLLM提供的,也可以是用户自定义的,如果是用户自定义的,首先需要注册进LazyLLM中才可以使用。在注册时直接使用 fc_register 注册器,该注册器已经建立 tool group,所以使用该工具管理类时,所有函数都统一注册进 tool 分组即可。待注册的函数需要对函数参数进行注解,并且需要对函数增加功能描述,以及参数类型和作用描述。以方便工具管理类能对函数解析传给LLM使用。

Parameters:

  • tools (List[str]) –

    工具名称字符串列表。

  • return_trace (bool, default: False ) –

    是否返回中间步骤和工具调用信息。

  • stream (bool) –

    是否以流式方式输出规划和解决过程。

Examples:

>>> from lazyllm.tools import ToolManager, fc_register
>>> import json
>>> from typing import Literal
>>> @fc_register("tool")
>>> def get_current_weather(location: str, unit: Literal["fahrenheit", "celsius"]="fahrenheit"):
...     '''
...     Get the current weather in a given location
...
...     Args:
...         location (str): The city and state, e.g. San Francisco, CA.
...         unit (str): The temperature unit to use. Infer this from the users location.
...     '''
...     if 'tokyo' in location.lower():
...         return json.dumps({'location': 'Tokyo', 'temperature': '10', 'unit': 'celsius'})
...     elif 'san francisco' in location.lower():
...         return json.dumps({'location': 'San Francisco', 'temperature': '72', 'unit': 'fahrenheit'})
...     elif 'paris' in location.lower():
...         return json.dumps({'location': 'Paris', 'temperature': '22', 'unit': 'celsius'})
...     elif 'beijing' in location.lower():
...         return json.dumps({'location': 'Beijing', 'temperature': '90', 'unit': 'fahrenheit'})
...     else:
...         return json.dumps({'location': location, 'temperature': 'unknown'})
...
>>> @fc_register("tool")
>>> def get_n_day_weather_forecast(location: str, num_days: int, unit: Literal["celsius", "fahrenheit"]='fahrenheit'):
...     '''
...     Get an N-day weather forecast
...
...     Args:
...         location (str): The city and state, e.g. San Francisco, CA.
...         num_days (int): The number of days to forecast.
...         unit (Literal['celsius', 'fahrenheit']): The temperature unit to use. Infer this from the users location.
...     '''
...     if 'tokyo' in location.lower():
...         return json.dumps({'location': 'Tokyo', 'temperature': '10', 'unit': 'celsius', "num_days": num_days})
...     elif 'san francisco' in location.lower():
...         return json.dumps({'location': 'San Francisco', 'temperature': '75', 'unit': 'fahrenheit', "num_days": num_days})
...     elif 'paris' in location.lower():
...         return json.dumps({'location': 'Paris', 'temperature': '25', 'unit': 'celsius', "num_days": num_days})
...     elif 'beijing' in location.lower():
...         return json.dumps({'location': 'Beijing', 'temperature': '85', 'unit': 'fahrenheit', "num_days": num_days})
...     else:
...         return json.dumps({'location': location, 'temperature': 'unknown'})
...
>>> tools = ["get_current_weather", "get_n_day_weather_forecast"]
>>> tm = ToolManager(tools)
>>> print(tm([{'name': 'get_n_day_weather_forecast', 'arguments': {'location': 'Beijing', 'num_days': 3}}])[0])
'{"location": "Beijing", "temperature": "85", "unit": "fahrenheit", "num_days": 3}'
Source code in lazyllm/tools/agent/toolsManager.py
class ToolManager(ModuleBase):
    """ToolManager是一个工具管理类,用于提供工具信息和工具调用给function call。

此管理类构造时需要传入工具名字符串列表。此处工具名可以是LazyLLM提供的,也可以是用户自定义的,如果是用户自定义的,首先需要注册进LazyLLM中才可以使用。在注册时直接使用 `fc_register` 注册器,该注册器已经建立 `tool` group,所以使用该工具管理类时,所有函数都统一注册进 `tool` 分组即可。待注册的函数需要对函数参数进行注解,并且需要对函数增加功能描述,以及参数类型和作用描述。以方便工具管理类能对函数解析传给LLM使用。

Args:
    tools (List[str]): 工具名称字符串列表。
    return_trace (bool): 是否返回中间步骤和工具调用信息。
    stream (bool): 是否以流式方式输出规划和解决过程。


Examples:
    >>> from lazyllm.tools import ToolManager, fc_register
    >>> import json
    >>> from typing import Literal
    >>> @fc_register("tool")
    >>> def get_current_weather(location: str, unit: Literal["fahrenheit", "celsius"]="fahrenheit"):
    ...     '''
    ...     Get the current weather in a given location
    ...
    ...     Args:
    ...         location (str): The city and state, e.g. San Francisco, CA.
    ...         unit (str): The temperature unit to use. Infer this from the users location.
    ...     '''
    ...     if 'tokyo' in location.lower():
    ...         return json.dumps({'location': 'Tokyo', 'temperature': '10', 'unit': 'celsius'})
    ...     elif 'san francisco' in location.lower():
    ...         return json.dumps({'location': 'San Francisco', 'temperature': '72', 'unit': 'fahrenheit'})
    ...     elif 'paris' in location.lower():
    ...         return json.dumps({'location': 'Paris', 'temperature': '22', 'unit': 'celsius'})
    ...     elif 'beijing' in location.lower():
    ...         return json.dumps({'location': 'Beijing', 'temperature': '90', 'unit': 'fahrenheit'})
    ...     else:
    ...         return json.dumps({'location': location, 'temperature': 'unknown'})
    ...
    >>> @fc_register("tool")
    >>> def get_n_day_weather_forecast(location: str, num_days: int, unit: Literal["celsius", "fahrenheit"]='fahrenheit'):
    ...     '''
    ...     Get an N-day weather forecast
    ...
    ...     Args:
    ...         location (str): The city and state, e.g. San Francisco, CA.
    ...         num_days (int): The number of days to forecast.
    ...         unit (Literal['celsius', 'fahrenheit']): The temperature unit to use. Infer this from the users location.
    ...     '''
    ...     if 'tokyo' in location.lower():
    ...         return json.dumps({'location': 'Tokyo', 'temperature': '10', 'unit': 'celsius', "num_days": num_days})
    ...     elif 'san francisco' in location.lower():
    ...         return json.dumps({'location': 'San Francisco', 'temperature': '75', 'unit': 'fahrenheit', "num_days": num_days})
    ...     elif 'paris' in location.lower():
    ...         return json.dumps({'location': 'Paris', 'temperature': '25', 'unit': 'celsius', "num_days": num_days})
    ...     elif 'beijing' in location.lower():
    ...         return json.dumps({'location': 'Beijing', 'temperature': '85', 'unit': 'fahrenheit', "num_days": num_days})
    ...     else:
    ...         return json.dumps({'location': location, 'temperature': 'unknown'})
    ...
    >>> tools = ["get_current_weather", "get_n_day_weather_forecast"]
    >>> tm = ToolManager(tools)
    >>> print(tm([{'name': 'get_n_day_weather_forecast', 'arguments': {'location': 'Beijing', 'num_days': 3}}])[0])
    '{"location": "Beijing", "temperature": "85", "unit": "fahrenheit", "num_days": 3}'
    """
    def __init__(self, tools: List[Union[str, Callable]], return_trace: bool = False):
        super().__init__(return_trace=return_trace)
        self._tools = self._load_tools(tools)
        self._format_tools()
        self._tools_desc = self._transform_to_openai_function()

    def _load_tools(self, tools: List[Union[str, Callable]]):
        if "tmp_tool" not in LazyLLMRegisterMetaClass.all_clses:
            register.new_group('tmp_tool')

        _tools = []
        for element in tools:
            if isinstance(element, str):
                _tools.append(getattr(lazyllm.tool, element)())
            elif isinstance(element, Callable):
                # just to convert `element` to the internal type in `Register`
                register('tmp_tool')(element)
                _tools.append(getattr(lazyllm.tmp_tool, element.__name__)())
                lazyllm.tmp_tool.remove(element.__name__)

        return _tools

    @property
    def all_tools(self):
        return self._tools

    @property
    def tools_description(self):
        return self._tools_desc

    @property
    def tools_info(self):
        return self._tool_call

    def _validate_tool(self, tool_name: str, tool_arguments: Dict[str, Any]):
        tool = self._tool_call.get(tool_name)
        if not tool:
            LOG.error(f'cannot find tool named [{tool_name}]')
            return False

        return tool.validate_parameters(tool_arguments)

    def _format_tools(self):
        if isinstance(self._tools, List):
            self._tool_call = {tool.name: tool for tool in self._tools}

    @staticmethod
    def _gen_args_info_from_moduletool_and_docstring(tool, parsed_docstring):
        """
        returns a dict of param names containing at least
          1. `type`
          2. `description` of params

        for example:
            args = {
                "foo": {
                    "enum": ["baz", "bar"],
                    "type": "string",
                    "description": "a string",
                },
                "bar": {
                    "type": "integer",
                    "description": "an integer",
                }
            }
        """
        tool_args = tool.args
        assert len(tool_args) == len(parsed_docstring.params), ("The parameter description and the actual "
                                                                "number of input parameters are inconsistent.")

        args_description = {}
        for param in parsed_docstring.params:
            args_description[param.arg_name] = param.description

        args = {}
        for k, v in tool_args.items():
            val = copy.deepcopy(v)
            val.pop("title", None)
            val.pop("default", None)
            args[k] = val if val else {"type": "string"}
            desc = args_description.get(k, None)
            if desc:
                args[k].update({"description": desc})
            else:
                raise ValueError(f"The actual input parameter '{k}' is not found "
                                 f"in the parameter description of tool '{tool.name}'.")
        return args

    def _transform_to_openai_function(self):
        if not isinstance(self._tools, List):
            raise TypeError(f"The tools type should be List instead of {type(self._tools)}")

        format_tools = []
        for tool in self._tools:
            try:
                parsed_docstring = docstring_parser.parse(tool.description)
                args = self._gen_args_info_from_moduletool_and_docstring(tool, parsed_docstring)
                required_arg_list = tool.params_schema.model_json_schema().get("required", [])
                func = {
                    "type": "function",
                    "function": {
                        "name": tool.name,
                        "description": parsed_docstring.short_description,
                        "parameters": {
                            "type": "object",
                            "properties": args,
                            "required": required_arg_list,
                        }
                    }
                }
                format_tools.append(func)
            except Exception:
                typehints_template = """
                def myfunc(arg1: str, arg2: Dict[str, Any], arg3: Literal["aaa", "bbb", "ccc"]="aaa"):
                    '''
                    Function description ...

                    Args:
                        arg1 (str): arg1 description.
                        arg2 (Dict[str, Any]): arg2 description
                        arg3 (Literal["aaa", "bbb", "ccc"]): arg3 description
                    '''
                """
                raise TypeError("Function description must include function description and "
                                f"parameter description, the format is as follows: {typehints_template}")
        return format_tools

    def forward(self, tools: Union[Dict[str, Any], List[Dict[str, Any]]], verbose: bool = False):
        tool_calls = [tools,] if isinstance(tools, dict) else tools
        tool_calls = [{"name": tool['name'], "arguments": json.loads(tool['arguments'])
                      if isinstance(tool['arguments'], str) else tool['arguments']} for tool in tool_calls]
        output = []
        flag_val = [True if self._validate_tool(tool['name'], tool['arguments']) else False for tool in tool_calls]
        tool_inputs = [tool_calls[idx]['arguments'] for idx, val in enumerate(flag_val) if val]
        tools = [self._tool_call[tool_calls[idx]['name']] for idx, val in enumerate(flag_val) if val]
        tool_diverter = lazyllm.diverter(tuple(tools))
        rets = tool_diverter(tuple(tool_inputs))
        res = iter(rets)
        rets = [next(res) if ele else None for ele in flag_val]
        for idx, tool in enumerate(tool_calls):
            if flag_val[idx]:
                ret = rets[idx]
                output.append(json.dumps(ret, ensure_ascii=False) if not isinstance(ret, str) else ret)
            else:
                output.append(f"{tool} parameters error.")

        return output

lazyllm.tools.ModuleTool

Bases: ModuleBase

用于构建工具模块的基类。

该类封装了函数签名和文档字符串的自动解析逻辑,可生成标准化的参数模式(基于 pydantic),并对输入进行校验和工具调用的标准封装。

__init__(self, verbose=False, return_trace=True) 初始化工具模块。

Parameters:

  • verbose (bool, default: False ) –

    是否在执行过程中输出详细日志。

  • return_trace (bool, default: True ) –

    是否在结果中保留中间执行痕迹。

Examples:

>>> from lazyllm.components import ModuleTool
>>> class AddTool(ModuleTool):
...     def apply(self, a: int, b: int) -> int:
...         '''Add two integers.
...         
...         Args:
...             a (int): First number.
...             b (int): Second number.
...         
...         Returns:
...             int: The sum of a and b.
...         '''
...         return a + b
>>> tool = AddTool()
>>> result = tool({'a': 3, 'b': 5})
>>> print(result)
8
Source code in lazyllm/tools/agent/toolsManager.py
class ModuleTool(ModuleBase, metaclass=LazyLLMRegisterMetaClass):
    """用于构建工具模块的基类。

该类封装了函数签名和文档字符串的自动解析逻辑,可生成标准化的参数模式(基于 pydantic),并对输入进行校验和工具调用的标准封装。

`__init__(self, verbose=False, return_trace=True)`
初始化工具模块。

Args:
    verbose (bool): 是否在执行过程中输出详细日志。
    return_trace (bool): 是否在结果中保留中间执行痕迹。


Examples:

    >>> from lazyllm.components import ModuleTool
    >>> class AddTool(ModuleTool):
    ...     def apply(self, a: int, b: int) -> int:
    ...         '''Add two integers.
    ...         
    ...         Args:
    ...             a (int): First number.
    ...             b (int): Second number.
    ...         
    ...         Returns:
    ...             int: The sum of a and b.
    ...         '''
    ...         return a + b
    >>> tool = AddTool()
    >>> result = tool({'a': 3, 'b': 5})
    >>> print(result)
    8
    """
    def __init__(self, verbose: bool = False, return_trace: bool = True):
        super().__init__(return_trace=return_trace)
        self._verbose = verbose
        self._name = self.apply.__name__\
            if hasattr(self.apply, "__name__") and self.apply.__name__ is not None\
            else (_ for _ in ()).throw(ValueError("Function must have a name."))
        self._description = self.apply.__doc__\
            if hasattr(self.apply, "__doc__") and self.apply.__doc__ is not None\
            else (_ for _ in ()).throw(ValueError("Function must have a docstring"))
        # strip space(s) and newlines before and after docstring, as RewooAgent requires
        self._description = self._description.strip(' \n')

        self._params_schema = self._load_function_schema(self.__class__.apply)

    def _load_function_schema(self, func: Callable) -> Type[BaseModel]:
        parsed_docstring = docstring_parser.parse(self._description)
        func_str_from_doc = _gen_empty_func_str_from_parsed_docstring(parsed_docstring)
        func_from_doc = _gen_func_from_str(func_str_from_doc, self._description)
        func_from_doc.__name__ = func.__name__
        doc_type_hints = get_type_hints(func_from_doc, globals(), locals())

        func_type_hints = get_type_hints(func, globals(), locals())

        _check_return_type_is_the_same(doc_type_hints, func_type_hints)

        signature = inspect.signature(func)
        has_var_args = False
        for _, param in signature.parameters.items():
            if param.kind == inspect.Parameter.VAR_POSITIONAL or\
               param.kind == inspect.Parameter.VAR_KEYWORD:
                has_var_args = True
                break

        if has_var_args:
            # we cannot get type hints from var args, so we get them from docstring
            self._type_hints = doc_type_hints
            signature = inspect.signature(func_from_doc)
        else:
            self._type_hints = func_type_hints
            # accomplish type_hints from docstring
            for name, type in doc_type_hints.items():
                self._type_hints.setdefault(name, type)

        self._return_type = self._type_hints.get('return') if self._type_hints else None

        fields = {
            name: (self._type_hints.get(name, Any), param.default
                   if param.default is not inspect.Parameter.empty
                   else ...)
            for name, param in signature.parameters.items()
        }

        return create_model(self._name, **fields)

    @property
    def name(self):
        return self._name

    @property
    def description(self):
        return self._description

    @property
    def params_schema(self) -> Type[BaseModel]:
        return self._params_schema

    @property
    def args(self) -> Dict[str, Any]:
        return self._params_schema.model_json_schema()["properties"]

    @property
    def required_args(self) -> Set[str]:
        return set(self._params_schema.model_json_schema().get("required", []))

    def apply(self, *args: Any, **kwargs: Any) -> Any:
        """
抽象方法,需在子类中实现具体逻辑。

此方法应根据传入的参数执行特定任务。

Raises:
    NotImplementedError: 如果未在子类中重写该方法。
"""
        raise NotImplementedError("Implement apply function in subclass")

    def _validate_input(self, tool_input: Dict[str, Any]) -> Dict[str, Any]:
        input_params = self._params_schema
        if isinstance(tool_input, dict):
            if input_params is not None:
                ret = input_params.model_validate(tool_input)
                return {key: getattr(ret, key) for key in ret.model_dump().keys() if key in tool_input}
            return tool_input
        elif isinstance(tool_input, str):
            if input_params is not None:
                key = next(iter(input_params.model_fields.keys()))
                input_params.model_validate({key: tool_input})
                arg_type = self._type_hints.get(key)
                if arg_type:
                    return {key: arg_type(tool_input)}
                return {key: tool_input}

            if len(self._type_hints) != 1:
                return tool_input
            arg_type = self._type_hints.values()[0]
            return arg_type(tool_input)
        else:
            raise TypeError(f"tool_input {tool_input} only supports dict and str.")

    def validate_parameters(self, arguments: Dict[str, Any]) -> bool:
        """
验证参数是否满足所需条件。

此方法会检查参数字典是否包含所有必须字段,并尝试进一步进行格式验证。

Args:
    arguments (Dict[str, Any]): 传入的参数字典。

Returns:
    bool: 若参数合法且完整,返回 True;否则返回 False。
"""
        if len(self.required_args.difference(set(arguments.keys()))) == 0:
            # contains all required parameters
            try:
                self._validate_input(arguments)
                return True
            except ValidationError:
                return False
        return False

    def forward(self, tool_input: Union[str, Dict[str, Any]], verbose: bool = False) -> Any:
        val_input = self._validate_input(tool_input)
        if isinstance(val_input, dict):
            ret = self.apply(**val_input)
        else:
            ret = self.apply(val_input)
        if verbose or self._verbose:
            lazyllm.LOG.debug(f"The output of tool {self.name} is {ret}")

        return ret

apply(*args, **kwargs)

抽象方法,需在子类中实现具体逻辑。

此方法应根据传入的参数执行特定任务。

Raises:

  • NotImplementedError

    如果未在子类中重写该方法。

Source code in lazyllm/tools/agent/toolsManager.py
    def apply(self, *args: Any, **kwargs: Any) -> Any:
        """
抽象方法,需在子类中实现具体逻辑。

此方法应根据传入的参数执行特定任务。

Raises:
    NotImplementedError: 如果未在子类中重写该方法。
"""
        raise NotImplementedError("Implement apply function in subclass")

validate_parameters(arguments)

验证参数是否满足所需条件。

此方法会检查参数字典是否包含所有必须字段,并尝试进一步进行格式验证。

Parameters:

  • arguments (Dict[str, Any]) –

    传入的参数字典。

Returns:

  • bool ( bool ) –

    若参数合法且完整,返回 True;否则返回 False。

Source code in lazyllm/tools/agent/toolsManager.py
    def validate_parameters(self, arguments: Dict[str, Any]) -> bool:
        """
验证参数是否满足所需条件。

此方法会检查参数字典是否包含所有必须字段,并尝试进一步进行格式验证。

Args:
    arguments (Dict[str, Any]): 传入的参数字典。

Returns:
    bool: 若参数合法且完整,返回 True;否则返回 False。
"""
        if len(self.required_args.difference(set(arguments.keys()))) == 0:
            # contains all required parameters
            try:
                self._validate_input(arguments)
                return True
            except ValidationError:
                return False
        return False

lazyllm.tools.FunctionCall

Bases: ModuleBase

FunctionCall是单轮工具调用类。当LLM自身信息不足以回答用户问题,需要结合外部工具获取辅助信息时,调用此类。
若LLM输出需要调用工具,则执行工具调用并返回调用结果;输出结果为List类型,包含当前轮的输入、模型输出和工具输出。
若不需工具调用,则直接返回LLM输出结果,输出为字符串类型。

Parameters:

  • llm (ModuleBase) –

    使用的LLM实例,支持TrainableModule或OnlineChatModule。

  • tools (List[Union[str, Callable]]) –

    LLM可调用的工具名称或Callable对象列表。

  • return_trace (Optional[bool], default: False ) –

    是否返回调用轨迹,默认为False。

  • stream (Optional[bool], default: False ) –

    是否启用流式输出,默认为False。

  • _prompt (Optional[str], default: None ) –

    自定义工具调用提示语,默认根据llm类型自动设置。

注意:tools中的工具需包含__doc__字段,且须遵循Google Python Style规范说明用途与参数。

Examples:

>>> import lazyllm
>>> from lazyllm.tools import fc_register, FunctionCall
>>> import json
>>> from typing import Literal
>>> @fc_register("tool")
>>> def get_current_weather(location: str, unit: Literal["fahrenheit", "celsius"] = 'fahrenheit'):
...     '''
...     Get the current weather in a given location
...
...     Args:
...         location (str): The city and state, e.g. San Francisco, CA.
...         unit (str): The temperature unit to use. Infer this from the users location.
...     '''
...     if 'tokyo' in location.lower():
...         return json.dumps({'location': 'Tokyo', 'temperature': '10', 'unit': 'celsius'})
...     elif 'san francisco' in location.lower():
...         return json.dumps({'location': 'San Francisco', 'temperature': '72', 'unit': 'fahrenheit'})
...     elif 'paris' in location.lower():
...         return json.dumps({'location': 'Paris', 'temperature': '22', 'unit': 'celsius'})
...     else:
...         return json.dumps({'location': location, 'temperature': 'unknown'})
...
>>> @fc_register("tool")
>>> def get_n_day_weather_forecast(location: str, num_days: int, unit: Literal["celsius", "fahrenheit"] = 'fahrenheit'):
...     '''
...     Get an N-day weather forecast
...
...     Args:
...         location (str): The city and state, e.g. San Francisco, CA.
...         num_days (int): The number of days to forecast.
...         unit (Literal['celsius', 'fahrenheit']): The temperature unit to use. Infer this from the users location.
...     '''
...     if 'tokyo' in location.lower():
...         return json.dumps({'location': 'Tokyo', 'temperature': '10', 'unit': 'celsius', "num_days": num_days})
...     elif 'san francisco' in location.lower():
...         return json.dumps({'location': 'San Francisco', 'temperature': '72', 'unit': 'fahrenheit', "num_days": num_days})
...     elif 'paris' in location.lower():
...         return json.dumps({'location': 'Paris', 'temperature': '22', 'unit': 'celsius', "num_days": num_days})
...     else:
...         return json.dumps({'location': location, 'temperature': 'unknown'})
...
>>> tools=["get_current_weather", "get_n_day_weather_forecast"]
>>> llm = lazyllm.TrainableModule("internlm2-chat-20b").start()  # or llm = lazyllm.OnlineChatModule("openai", stream=False)
>>> query = "What's the weather like today in celsius in Tokyo."
>>> fc = FunctionCall(llm, tools)
>>> ret = fc(query)
>>> print(ret)
["What's the weather like today in celsius in Tokyo.", {'role': 'assistant', 'content': '
', 'tool_calls': [{'id': 'da19cddac0584869879deb1315356d2a', 'type': 'function', 'function': {'name': 'get_current_weather', 'arguments': {'location': 'Tokyo', 'unit': 'celsius'}}}]}, [{'role': 'tool', 'content': '{"location": "Tokyo", "temperature": "10", "unit": "celsius"}', 'tool_call_id': 'da19cddac0584869879deb1315356d2a', 'name': 'get_current_weather'}]]
>>> query = "Hello"
>>> ret = fc(query)
>>> print(ret)
'Hello! How can I assist you today?'
Source code in lazyllm/tools/agent/functionCall.py
class FunctionCall(ModuleBase):
    """FunctionCall是单轮工具调用类。当LLM自身信息不足以回答用户问题,需要结合外部工具获取辅助信息时,调用此类。  
若LLM输出需要调用工具,则执行工具调用并返回调用结果;输出结果为List类型,包含当前轮的输入、模型输出和工具输出。  
若不需工具调用,则直接返回LLM输出结果,输出为字符串类型。

Args:
    llm (ModuleBase): 使用的LLM实例,支持TrainableModule或OnlineChatModule。
    tools (List[Union[str, Callable]]): LLM可调用的工具名称或Callable对象列表。
    return_trace (Optional[bool]): 是否返回调用轨迹,默认为False。
    stream (Optional[bool]): 是否启用流式输出,默认为False。
    _prompt (Optional[str]): 自定义工具调用提示语,默认根据llm类型自动设置。

注意:tools中的工具需包含`__doc__`字段,且须遵循[Google Python Style](https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings)规范说明用途与参数。


Examples:
    >>> import lazyllm
    >>> from lazyllm.tools import fc_register, FunctionCall
    >>> import json
    >>> from typing import Literal
    >>> @fc_register("tool")
    >>> def get_current_weather(location: str, unit: Literal["fahrenheit", "celsius"] = 'fahrenheit'):
    ...     '''
    ...     Get the current weather in a given location
    ...
    ...     Args:
    ...         location (str): The city and state, e.g. San Francisco, CA.
    ...         unit (str): The temperature unit to use. Infer this from the users location.
    ...     '''
    ...     if 'tokyo' in location.lower():
    ...         return json.dumps({'location': 'Tokyo', 'temperature': '10', 'unit': 'celsius'})
    ...     elif 'san francisco' in location.lower():
    ...         return json.dumps({'location': 'San Francisco', 'temperature': '72', 'unit': 'fahrenheit'})
    ...     elif 'paris' in location.lower():
    ...         return json.dumps({'location': 'Paris', 'temperature': '22', 'unit': 'celsius'})
    ...     else:
    ...         return json.dumps({'location': location, 'temperature': 'unknown'})
    ...
    >>> @fc_register("tool")
    >>> def get_n_day_weather_forecast(location: str, num_days: int, unit: Literal["celsius", "fahrenheit"] = 'fahrenheit'):
    ...     '''
    ...     Get an N-day weather forecast
    ...
    ...     Args:
    ...         location (str): The city and state, e.g. San Francisco, CA.
    ...         num_days (int): The number of days to forecast.
    ...         unit (Literal['celsius', 'fahrenheit']): The temperature unit to use. Infer this from the users location.
    ...     '''
    ...     if 'tokyo' in location.lower():
    ...         return json.dumps({'location': 'Tokyo', 'temperature': '10', 'unit': 'celsius', "num_days": num_days})
    ...     elif 'san francisco' in location.lower():
    ...         return json.dumps({'location': 'San Francisco', 'temperature': '72', 'unit': 'fahrenheit', "num_days": num_days})
    ...     elif 'paris' in location.lower():
    ...         return json.dumps({'location': 'Paris', 'temperature': '22', 'unit': 'celsius', "num_days": num_days})
    ...     else:
    ...         return json.dumps({'location': location, 'temperature': 'unknown'})
    ...
    >>> tools=["get_current_weather", "get_n_day_weather_forecast"]
    >>> llm = lazyllm.TrainableModule("internlm2-chat-20b").start()  # or llm = lazyllm.OnlineChatModule("openai", stream=False)
    >>> query = "What's the weather like today in celsius in Tokyo."
    >>> fc = FunctionCall(llm, tools)
    >>> ret = fc(query)
    >>> print(ret)
    ["What's the weather like today in celsius in Tokyo.", {'role': 'assistant', 'content': '
    ', 'tool_calls': [{'id': 'da19cddac0584869879deb1315356d2a', 'type': 'function', 'function': {'name': 'get_current_weather', 'arguments': {'location': 'Tokyo', 'unit': 'celsius'}}}]}, [{'role': 'tool', 'content': '{"location": "Tokyo", "temperature": "10", "unit": "celsius"}', 'tool_call_id': 'da19cddac0584869879deb1315356d2a', 'name': 'get_current_weather'}]]
    >>> query = "Hello"
    >>> ret = fc(query)
    >>> print(ret)
    'Hello! How can I assist you today?'
    """

    def __init__(self, llm, tools: List[Union[str, Callable]], *, return_trace: bool = False,
                 stream: bool = False, _prompt: str = None):
        super().__init__(return_trace=return_trace)
        if isinstance(llm, OnlineChatModule) and llm.series == "QWEN" and llm._stream is True:
            raise ValueError("The qwen platform does not currently support stream function calls.")
        if _prompt is None:
            _prompt = FC_PROMPT_ONLINE if isinstance(llm, OnlineChatModule) else FC_PROMPT_LOCAL

        self._tools_manager = ToolManager(tools, return_trace=return_trace)
        self._prompter = ChatPrompter(instruction=_prompt, tools=self._tools_manager.tools_description)\
            .pre_hook(function_call_hook)
        self._llm = llm.share(prompt=self._prompter, format=FunctionCallFormatter()).used_by(self._module_id)
        with pipeline() as self._impl:
            self._impl.ins = StreamResponse('Received instruction:', prefix_color=Color.yellow,
                                            color=Color.green, stream=stream)
            self._impl.m1 = self._llm
            self._impl.m2 = self._parser
            self._impl.dis = StreamResponse('Decision-making or result in this round:',
                                            prefix_color=Color.yellow, color=Color.green, stream=stream)
            self._impl.m3 = ifs(lambda x: isinstance(x, list),
                                pipeline(self._tools_manager, StreamResponse('Tool-Call result:',
                                         prefix_color=Color.yellow, color=Color.green, stream=stream)),
                                lambda out: out)
            self._impl.m4 = self._tool_post_action | bind(input=self._impl.input, llm_output=self._impl.m1)

    def _parser(self, llm_output: Union[str, List[Dict[str, Any]]]):
        LOG.debug(f"llm_output: {llm_output}")
        if isinstance(llm_output, list):
            res = []
            for item in llm_output:
                if isinstance(item, str):
                    continue
                arguments = item.get('function', {}).get('arguments', '')
                arguments = json.loads(arguments) if isinstance(arguments, str) else arguments
                res.append({"name": item.get('function', {}).get('name', ''), 'arguments': arguments})
            return res
        elif isinstance(llm_output, str):
            return llm_output
        else:
            raise TypeError(f"The {llm_output} type currently is only supports `list` and `str`,"
                            f" and does not support {type(llm_output)}.")

    def _tool_post_action(self, output: Union[str, List[str]], input: Union[str, List],
                          llm_output: List[Dict[str, Any]]):
        if isinstance(output, list):
            ret = []
            if isinstance(input, str):
                ret.append(input)
            elif isinstance(input, list):
                ret.append(input[-1])
            else:
                raise TypeError(f"The input type currently only supports `str` and `list`, "
                                f"and does not support {type(input)}.")

            content = "".join([item for item in llm_output if isinstance(item, str)])
            llm_output = [item for item in llm_output if not isinstance(item, str)]
            ret.append({"role": "assistant", "content": content, "tool_calls": llm_output})
            ret.append([{"role": "tool", "content": out, "tool_call_id": llm_output[idx]["id"],
                         "name": llm_output[idx]["function"]["name"]}
                        for idx, out in enumerate(output)])
            LOG.debug(f"functionCall result: {ret}")
            return ret
        elif isinstance(output, str):
            return output
        else:
            raise TypeError(f"The {output} type currently is only supports `list` and `str`,"
                            f" and does not support {type(output)}.")

    def forward(self, input: str, llm_chat_history: List[Dict[str, Any]] = None):
        globals['chat_history'].setdefault(self._llm._module_id, [])
        if llm_chat_history is not None:
            globals['chat_history'][self._llm._module_id] = llm_chat_history
        return self._impl(input)

lazyllm.tools.FunctionCallFormatter

Bases: JsonFormatter

用于解析函数调用结构消息的格式化器。

该类继承自 JsonFormatter,用于从包含工具调用信息的消息字符串中提取 JSON 结构,并在需要时通过全局分隔符拆分内容。

私有方法

_load(msg) 解析输入的消息字符串,提取其中的 JSON 格式的工具调用结构(如果存在)。

Examples:

>>> from lazyllm.components import FunctionCallFormatter
>>> formatter = FunctionCallFormatter()
>>> msg = "Please call this tool. <TOOL> [{"name": "search", "args": {"query": "weather"}}]"
>>> result = formatter._load(msg)
>>> print(result)
... [{'name': 'search', 'args': {'query': 'weather'}}, 'Please call this tool. ']
Source code in lazyllm/tools/agent/functionCallFormatter.py
class FunctionCallFormatter(JsonFormatter):
    """用于解析函数调用结构消息的格式化器。

该类继承自 `JsonFormatter`,用于从包含工具调用信息的消息字符串中提取 JSON 结构,并在需要时通过全局分隔符拆分内容。

私有方法:
    _load(msg)
        解析输入的消息字符串,提取其中的 JSON 格式的工具调用结构(如果存在)。


Examples:
    >>> from lazyllm.components import FunctionCallFormatter
    >>> formatter = FunctionCallFormatter()
    >>> msg = "Please call this tool. <TOOL> [{\"name\": \"search\", \"args\": {\"query\": \"weather\"}}]"
    >>> result = formatter._load(msg)
    >>> print(result)
    ... [{'name': 'search', 'args': {'query': 'weather'}}, 'Please call this tool. ']
    """
    def _load(self, msg: str):
        if "{" not in msg:
            return msg
        if globals['tool_delimiter'] in msg:
            content, msg = msg.split(globals['tool_delimiter'])
            assert msg.count("{") == msg.count("}"), f"{msg} is not a valid json string."
            try:
                json_strs = json.loads(msg)
                res = []
                for json_str in json_strs:
                    res.append(json_str)
                if content:
                    res.append(content)
                return res
            except Exception:
                return msg

        return msg

lazyllm.tools.FunctionCallAgent

Bases: ModuleBase

FunctionCallAgent是一个使用工具调用方式进行完整工具调用的代理,即回答用户问题时,LLM如果需要通过工具获取外部知识,就会调用工具,并将工具的返回结果反馈给LLM,最后由LLM进行汇总输出。

Parameters:

  • llm (ModuleBase) –

    要使用的LLM,可以是TrainableModule或OnlineChatModule。

  • tools (List[str]) –

    LLM 使用的工具名称列表。

  • max_retries (int, default: 5 ) –

    工具调用迭代的最大次数。默认值为5。

Examples:

>>> import lazyllm
>>> from lazyllm.tools import fc_register, FunctionCallAgent
>>> import json
>>> from typing import Literal
>>> @fc_register("tool")
>>> def get_current_weather(location: str, unit: Literal["fahrenheit", "celsius"]='fahrenheit'):
...     '''
...     Get the current weather in a given location
...
...     Args:
...         location (str): The city and state, e.g. San Francisco, CA.
...         unit (str): The temperature unit to use. Infer this from the users location.
...     '''
...     if 'tokyo' in location.lower():
...         return json.dumps({'location': 'Tokyo', 'temperature': '10', 'unit': 'celsius'})
...     elif 'san francisco' in location.lower():
...         return json.dumps({'location': 'San Francisco', 'temperature': '72', 'unit': 'fahrenheit'})
...     elif 'paris' in location.lower():
...         return json.dumps({'location': 'Paris', 'temperature': '22', 'unit': 'celsius'})
...     elif 'beijing' in location.lower():
...         return json.dumps({'location': 'Beijing', 'temperature': '90', 'unit': 'Fahrenheit'})
...     else:
...         return json.dumps({'location': location, 'temperature': 'unknown'})
...
>>> @fc_register("tool")
>>> def get_n_day_weather_forecast(location: str, num_days: int, unit: Literal["celsius", "fahrenheit"]='fahrenheit'):
...     '''
...     Get an N-day weather forecast
...
...     Args:
...         location (str): The city and state, e.g. San Francisco, CA.
...         num_days (int): The number of days to forecast.
...         unit (Literal['celsius', 'fahrenheit']): The temperature unit to use. Infer this from the users location.
...     '''
...     if 'tokyo' in location.lower():
...         return json.dumps({'location': 'Tokyo', 'temperature': '10', 'unit': 'celsius', "num_days": num_days})
...     elif 'san francisco' in location.lower():
...         return json.dumps({'location': 'San Francisco', 'temperature': '75', 'unit': 'fahrenheit', "num_days": num_days})
...     elif 'paris' in location.lower():
...         return json.dumps({'location': 'Paris', 'temperature': '25', 'unit': 'celsius', "num_days": num_days})
...     elif 'beijing' in location.lower():
...         return json.dumps({'location': 'Beijing', 'temperature': '85', 'unit': 'fahrenheit', "num_days": num_days})
...     else:
...         return json.dumps({'location': location, 'temperature': 'unknown'})
...
>>> tools = ['get_current_weather', 'get_n_day_weather_forecast']
>>> llm = lazyllm.TrainableModule("internlm2-chat-20b").start()  # or llm = lazyllm.OnlineChatModule(source="sensenova")
>>> agent = FunctionCallAgent(llm, tools)
>>> query = "What's the weather like today in celsius in Tokyo and Paris."
>>> res = agent(query)
>>> print(res)
'The current weather in Tokyo is 10 degrees Celsius, and in Paris, it is 22 degrees Celsius.'
>>> query = "Hello"
>>> res = agent(query)
>>> print(res)
'Hello! How can I assist you today?'
Source code in lazyllm/tools/agent/functionCall.py
class FunctionCallAgent(ModuleBase):
    """FunctionCallAgent是一个使用工具调用方式进行完整工具调用的代理,即回答用户问题时,LLM如果需要通过工具获取外部知识,就会调用工具,并将工具的返回结果反馈给LLM,最后由LLM进行汇总输出。

Args:
    llm (ModuleBase): 要使用的LLM,可以是TrainableModule或OnlineChatModule。
    tools (List[str]): LLM 使用的工具名称列表。
    max_retries (int): 工具调用迭代的最大次数。默认值为5。


Examples:
    >>> import lazyllm
    >>> from lazyllm.tools import fc_register, FunctionCallAgent
    >>> import json
    >>> from typing import Literal
    >>> @fc_register("tool")
    >>> def get_current_weather(location: str, unit: Literal["fahrenheit", "celsius"]='fahrenheit'):
    ...     '''
    ...     Get the current weather in a given location
    ...
    ...     Args:
    ...         location (str): The city and state, e.g. San Francisco, CA.
    ...         unit (str): The temperature unit to use. Infer this from the users location.
    ...     '''
    ...     if 'tokyo' in location.lower():
    ...         return json.dumps({'location': 'Tokyo', 'temperature': '10', 'unit': 'celsius'})
    ...     elif 'san francisco' in location.lower():
    ...         return json.dumps({'location': 'San Francisco', 'temperature': '72', 'unit': 'fahrenheit'})
    ...     elif 'paris' in location.lower():
    ...         return json.dumps({'location': 'Paris', 'temperature': '22', 'unit': 'celsius'})
    ...     elif 'beijing' in location.lower():
    ...         return json.dumps({'location': 'Beijing', 'temperature': '90', 'unit': 'Fahrenheit'})
    ...     else:
    ...         return json.dumps({'location': location, 'temperature': 'unknown'})
    ...
    >>> @fc_register("tool")
    >>> def get_n_day_weather_forecast(location: str, num_days: int, unit: Literal["celsius", "fahrenheit"]='fahrenheit'):
    ...     '''
    ...     Get an N-day weather forecast
    ...
    ...     Args:
    ...         location (str): The city and state, e.g. San Francisco, CA.
    ...         num_days (int): The number of days to forecast.
    ...         unit (Literal['celsius', 'fahrenheit']): The temperature unit to use. Infer this from the users location.
    ...     '''
    ...     if 'tokyo' in location.lower():
    ...         return json.dumps({'location': 'Tokyo', 'temperature': '10', 'unit': 'celsius', "num_days": num_days})
    ...     elif 'san francisco' in location.lower():
    ...         return json.dumps({'location': 'San Francisco', 'temperature': '75', 'unit': 'fahrenheit', "num_days": num_days})
    ...     elif 'paris' in location.lower():
    ...         return json.dumps({'location': 'Paris', 'temperature': '25', 'unit': 'celsius', "num_days": num_days})
    ...     elif 'beijing' in location.lower():
    ...         return json.dumps({'location': 'Beijing', 'temperature': '85', 'unit': 'fahrenheit', "num_days": num_days})
    ...     else:
    ...         return json.dumps({'location': location, 'temperature': 'unknown'})
    ...
    >>> tools = ['get_current_weather', 'get_n_day_weather_forecast']
    >>> llm = lazyllm.TrainableModule("internlm2-chat-20b").start()  # or llm = lazyllm.OnlineChatModule(source="sensenova")
    >>> agent = FunctionCallAgent(llm, tools)
    >>> query = "What's the weather like today in celsius in Tokyo and Paris."
    >>> res = agent(query)
    >>> print(res)
    'The current weather in Tokyo is 10 degrees Celsius, and in Paris, it is 22 degrees Celsius.'
    >>> query = "Hello"
    >>> res = agent(query)
    >>> print(res)
    'Hello! How can I assist you today?'
    """
    def __init__(self, llm, tools: List[str], max_retries: int = 5, return_trace: bool = False, stream: bool = False):
        super().__init__(return_trace=return_trace)
        self._max_retries = max_retries
        self._fc = FunctionCall(llm, tools, return_trace=return_trace, stream=stream)
        self._agent = loop(self._fc, stop_condition=lambda x: isinstance(x, str), count=self._max_retries)
        self._fc._llm.used_by(self._module_id)

    def forward(self, query: str, llm_chat_history: List[Dict[str, Any]] = None):
        ret = self._agent(query, llm_chat_history) if llm_chat_history is not None else self._agent(query)
        return ret if isinstance(ret, str) else (_ for _ in ()).throw(
            ValueError(f"After retrying {self._max_retries} times, the function call agent still "
                       "failed to call successfully."))

lazyllm.tools.ReactAgent

Bases: ModuleBase

ReactAgent是按照 Thought->Action->Observation->Thought...->Finish 的流程一步一步的通过LLM和工具调用来显示解决用户问题的步骤,以及最后给用户的答案。

Parameters:

  • llm (ModuleBase) –

    要使用的LLM,可以是TrainableModule或OnlineChatModule。

  • tools (List[str]) –

    LLM 使用的工具名称列表。

  • max_retries (int, default: 5 ) –

    工具调用迭代的最大次数。默认值为5。

  • return_trace (bool, default: False ) –

    是否返回中间步骤和工具调用信息。

  • stream (bool, default: False ) –

    是否以流式方式输出规划和解决过程。

Examples:

>>> import lazyllm
>>> from lazyllm.tools import fc_register, ReactAgent
>>> @fc_register("tool")
>>> def multiply_tool(a: int, b: int) -> int:
...     '''
...     Multiply two integers and return the result integer
...
...     Args:
...         a (int): multiplier
...         b (int): multiplier
...     '''
...     return a * b
...
>>> @fc_register("tool")
>>> def add_tool(a: int, b: int):
...     '''
...     Add two integers and returns the result integer
...
...     Args:
...         a (int): addend
...         b (int): addend
...     '''
...     return a + b
...
>>> tools = ["multiply_tool", "add_tool"]
>>> llm = lazyllm.TrainableModule("internlm2-chat-20b").start()   # or llm = lazyllm.OnlineChatModule(source="sensenova")
>>> agent = ReactAgent(llm, tools)
>>> query = "What is 20+(2*4)? Calculate step by step."
>>> res = agent(query)
>>> print(res)
'Answer: The result of 20+(2*4) is 28.'
Source code in lazyllm/tools/agent/reactAgent.py
class ReactAgent(ModuleBase):
    """ReactAgent是按照 `Thought->Action->Observation->Thought...->Finish` 的流程一步一步的通过LLM和工具调用来显示解决用户问题的步骤,以及最后给用户的答案。

Args:
    llm (ModuleBase): 要使用的LLM,可以是TrainableModule或OnlineChatModule。
    tools (List[str]): LLM 使用的工具名称列表。
    max_retries (int): 工具调用迭代的最大次数。默认值为5。
    return_trace (bool): 是否返回中间步骤和工具调用信息。
    stream (bool): 是否以流式方式输出规划和解决过程。


Examples:
    >>> import lazyllm
    >>> from lazyllm.tools import fc_register, ReactAgent
    >>> @fc_register("tool")
    >>> def multiply_tool(a: int, b: int) -> int:
    ...     '''
    ...     Multiply two integers and return the result integer
    ...
    ...     Args:
    ...         a (int): multiplier
    ...         b (int): multiplier
    ...     '''
    ...     return a * b
    ...
    >>> @fc_register("tool")
    >>> def add_tool(a: int, b: int):
    ...     '''
    ...     Add two integers and returns the result integer
    ...
    ...     Args:
    ...         a (int): addend
    ...         b (int): addend
    ...     '''
    ...     return a + b
    ...
    >>> tools = ["multiply_tool", "add_tool"]
    >>> llm = lazyllm.TrainableModule("internlm2-chat-20b").start()   # or llm = lazyllm.OnlineChatModule(source="sensenova")
    >>> agent = ReactAgent(llm, tools)
    >>> query = "What is 20+(2*4)? Calculate step by step."
    >>> res = agent(query)
    >>> print(res)
    'Answer: The result of 20+(2*4) is 28.'
    """
    def __init__(self, llm, tools: List[str], max_retries: int = 5, return_trace: bool = False,
                 prompt: str = None, stream: bool = False):
        super().__init__(return_trace=return_trace)
        self._max_retries = max_retries
        assert llm and tools, "llm and tools cannot be empty."

        if not prompt:
            prompt = INSTRUCTION.replace("{TOKENIZED_PROMPT}", WITHOUT_TOKEN_PROMPT if isinstance(llm, OnlineChatModule)
                                         else WITH_TOKEN_PROMPT)
            prompt = prompt.replace("{tool_names}", json.dumps([t.__name__ if callable(t) else t for t in tools],
                                                               ensure_ascii=False))
        self._agent = loop(FunctionCall(llm, tools, _prompt=prompt, return_trace=return_trace, stream=stream),
                           stop_condition=lambda x: isinstance(x, str), count=self._max_retries)

    def forward(self, query: str, llm_chat_history: List[Dict[str, Any]] = None):
        ret = self._agent(query, llm_chat_history) if llm_chat_history is not None else self._agent(query)
        return ret if isinstance(ret, str) else (_ for _ in ()).throw(ValueError(f"After retrying \
            {self._max_retries} times, the function call agent still failes to call successfully."))

lazyllm.tools.PlanAndSolveAgent

Bases: ModuleBase

PlanAndSolveAgent由两个组件组成,首先,由planner将整个任务分解为更小的子任务,然后由solver根据计划执行这些子任务,其中可能会涉及到工具调用,最后将答案返回给用户。

Parameters:

  • llm (ModuleBase, default: None ) –

    要使用的LLM,可以是TrainableModule或OnlineChatModule。和plan_llm、solve_llm互斥,要么设置llm(planner和solver公用一个LLM),要么设置plan_llm和solve_llm,或者只指定llm(用来设置planner)和solve_llm,其它情况均认为是无效的。

  • tools (List[str], default: [] ) –

    LLM使用的工具名称列表。

  • plan_llm (ModuleBase, default: None ) –

    planner要使用的LLM,可以是TrainableModule或OnlineChatModule。

  • solve_llm (ModuleBase, default: None ) –

    solver要使用的LLM,可以是TrainableModule或OnlineChatModule。

  • max_retries (int, default: 5 ) –

    工具调用迭代的最大次数。默认值为5。

  • return_trace (bool, default: False ) –

    是否返回中间步骤和工具调用信息。

  • stream (bool, default: False ) –

    是否以流式方式输出规划和解决过程。

Examples:

>>> import lazyllm
>>> from lazyllm.tools import fc_register, PlanAndSolveAgent
>>> @fc_register("tool")
>>> def multiply(a: int, b: int) -> int:
...     '''
...     Multiply two integers and return the result integer
...
...     Args:
...         a (int): multiplier
...         b (int): multiplier
...     '''
...     return a * b
...
>>> @fc_register("tool")
>>> def add(a: int, b: int):
...     '''
...     Add two integers and returns the result integer
...
...     Args:
...         a (int): addend
...         b (int): addend
...     '''
...     return a + b
...
>>> tools = ["multiply", "add"]
>>> llm = lazyllm.TrainableModule("internlm2-chat-20b").start()  # or llm = lazyllm.OnlineChatModule(source="sensenova")
>>> agent = PlanAndSolveAgent(llm, tools)
>>> query = "What is 20+(2*4)? Calculate step by step."
>>> res = agent(query)
>>> print(res)
'The final answer is 28.'
Source code in lazyllm/tools/agent/planAndSolveAgent.py
class PlanAndSolveAgent(ModuleBase):
    """PlanAndSolveAgent由两个组件组成,首先,由planner将整个任务分解为更小的子任务,然后由solver根据计划执行这些子任务,其中可能会涉及到工具调用,最后将答案返回给用户。

Args:
    llm (ModuleBase): 要使用的LLM,可以是TrainableModule或OnlineChatModule。和plan_llm、solve_llm互斥,要么设置llm(planner和solver公用一个LLM),要么设置plan_llm和solve_llm,或者只指定llm(用来设置planner)和solve_llm,其它情况均认为是无效的。
    tools (List[str]): LLM使用的工具名称列表。
    plan_llm (ModuleBase): planner要使用的LLM,可以是TrainableModule或OnlineChatModule。
    solve_llm (ModuleBase): solver要使用的LLM,可以是TrainableModule或OnlineChatModule。
    max_retries (int): 工具调用迭代的最大次数。默认值为5。
    return_trace (bool): 是否返回中间步骤和工具调用信息。
    stream (bool): 是否以流式方式输出规划和解决过程。


Examples:
    >>> import lazyllm
    >>> from lazyllm.tools import fc_register, PlanAndSolveAgent
    >>> @fc_register("tool")
    >>> def multiply(a: int, b: int) -> int:
    ...     '''
    ...     Multiply two integers and return the result integer
    ...
    ...     Args:
    ...         a (int): multiplier
    ...         b (int): multiplier
    ...     '''
    ...     return a * b
    ...
    >>> @fc_register("tool")
    >>> def add(a: int, b: int):
    ...     '''
    ...     Add two integers and returns the result integer
    ...
    ...     Args:
    ...         a (int): addend
    ...         b (int): addend
    ...     '''
    ...     return a + b
    ...
    >>> tools = ["multiply", "add"]
    >>> llm = lazyllm.TrainableModule("internlm2-chat-20b").start()  # or llm = lazyllm.OnlineChatModule(source="sensenova")
    >>> agent = PlanAndSolveAgent(llm, tools)
    >>> query = "What is 20+(2*4)? Calculate step by step."
    >>> res = agent(query)
    >>> print(res)
    'The final answer is 28.'
    """
    def __init__(self, llm: Union[ModuleBase, None] = None, tools: List[str] = [], *,  # noqa B006
                 plan_llm: Union[ModuleBase, None] = None, solve_llm: Union[ModuleBase, None] = None,
                 max_retries: int = 5, return_trace: bool = False, stream: bool = False):
        super().__init__(return_trace=return_trace)
        self._max_retries = max_retries
        assert (llm is None and plan_llm and solve_llm) or (llm and plan_llm is None), 'Either specify only llm \
               without specify plan and solve, or specify only plan and solve without specifying llm, or specify \
               both llm and solve. Other situations are not allowed.'
        assert tools, "tools cannot be empty."
        s = dict(prefix='I will give a plan first:\n', prefix_color=Color.blue, color=Color.green) if stream else False
        self._plan_llm = ((plan_llm or llm).share(prompt=ChatPrompter(instruction=PLANNER_PROMPT),
                                                  stream=s).used_by(self._module_id))
        self._solve_llm = (solve_llm or llm).share().used_by(self._module_id)
        self._tools = tools
        with pipeline() as self._agent:
            self._agent.plan = self._plan_llm
            self._agent.parse = (lambda text, query: package([], '', [v for v in re.split("\n\\s*\\d+\\. ", text)[1:]],
                                 query)) | bind(query=self._agent.input)
            with loop(stop_condition=lambda pre, res, steps, query: len(steps) == 0) as self._agent.lp:
                self._agent.lp.pre_action = self._pre_action
                self._agent.lp.solve = FunctionCallAgent(self._solve_llm, tools=self._tools,
                                                         return_trace=return_trace, stream=stream)
                self._agent.lp.post_action = self._post_action | bind(self._agent.lp.input[0][0], _0,
                                                                      self._agent.lp.input[0][2],
                                                                      self._agent.lp.input[0][3])

            self._agent.post_action = lambda pre, res, steps, query: res

    def _pre_action(self, pre_steps, response, steps, query):
        result = package(SOLVER_PROMPT.format(previous_steps="\n".join(pre_steps), current_step=steps[0],
                                              objective=query) + "input: " + response + "\n" + steps[0], [])
        return result

    def _post_action(self, pre_steps: List[str], response: str, steps: List[str], query: str):
        LOG.debug(f"current step: {steps[0]}, response: {response}")
        pre_steps.append(steps.pop(0))
        return package(pre_steps, response, steps, query)

    def forward(self, query: str):
        return self._agent(query)

lazyllm.tools.ReWOOAgent

Bases: ModuleBase

ReWOOAgent包含三个部分:Planner、Worker和Solver。其中,Planner使用可预见推理能力为复杂任务创建解决方案蓝图;Worker通过工具调用来与环境交互,并将实际证据或观察结果填充到指令中;Solver处理所有计划和证据以制定原始任务或问题的解决方案。

Parameters:

  • llm (ModuleBase, default: None ) –

    要使用的LLM,可以是TrainableModule或OnlineChatModule。和plan_llm、solve_llm互斥,要么设置llm(planner和solver公用一个LLM),要么设置plan_llm和solve_llm,或者只指定llm(用来设置planner)和solve_llm,其它情况均认为是无效的。

  • tools (List[str], default: [] ) –

    LLM使用的工具名称列表。

  • plan_llm (ModuleBase, default: None ) –

    planner要使用的LLM,可以是TrainableModule或OnlineChatModule。

  • solve_llm (ModuleBase, default: None ) –

    solver要使用的LLM,可以是TrainableModule或OnlineChatModule。

  • max_retries (int) –

    工具调用迭代的最大次数。默认值为5。

  • return_trace (bool, default: False ) –

    是否返回中间步骤和工具调用信息。

  • stream (bool, default: False ) –

    是否以流式方式输出规划和解决过程。

Examples:

>>> import lazyllm
>>> import wikipedia
>>> from lazyllm.tools import fc_register, ReWOOAgent
>>> @fc_register("tool")
>>> def WikipediaWorker(input: str):
...     '''
...     Worker that search for similar page contents from Wikipedia. Useful when you need to get holistic knowledge about people, places, companies, historical events, or other subjects. The response are long and might contain some irrelevant information. Input should be a search query.
...
...     Args:
...         input (str): search query.
...     '''
...     try:
...         evidence = wikipedia.page(input).content
...         evidence = evidence.split("\n\n")[0]
...     except wikipedia.PageError:
...         evidence = f"Could not find [{input}]. Similar: {wikipedia.search(input)}"
...     except wikipedia.DisambiguationError:
...         evidence = f"Could not find [{input}]. Similar: {wikipedia.search(input)}"
...     return evidence
...
>>> @fc_register("tool")
>>> def LLMWorker(input: str):
...     '''
...     A pretrained LLM like yourself. Useful when you need to act with general world knowledge and common sense. Prioritize it when you are confident in solving the problem yourself. Input can be any instruction.
...
...     Args:
...         input (str): instruction
...     '''
...     llm = lazyllm.OnlineChatModule(source="glm")
...     query = f"Respond in short directly with no extra words.\n\n{input}"
...     response = llm(query, llm_chat_history=[])
...     return response
...
>>> tools = ["WikipediaWorker", "LLMWorker"]
>>> llm = lazyllm.TrainableModule("GLM-4-9B-Chat").deploy_method(lazyllm.deploy.vllm).start()  # or llm = lazyllm.OnlineChatModule(source="sensenova")
>>> agent = ReWOOAgent(llm, tools)
>>> query = "What is the name of the cognac house that makes the main ingredient in The Hennchata?"
>>> res = agent(query)
>>> print(res)
'
Hennessy '
Source code in lazyllm/tools/agent/rewooAgent.py
class ReWOOAgent(ModuleBase):
    """ReWOOAgent包含三个部分:Planner、Worker和Solver。其中,Planner使用可预见推理能力为复杂任务创建解决方案蓝图;Worker通过工具调用来与环境交互,并将实际证据或观察结果填充到指令中;Solver处理所有计划和证据以制定原始任务或问题的解决方案。

Args:
    llm (ModuleBase): 要使用的LLM,可以是TrainableModule或OnlineChatModule。和plan_llm、solve_llm互斥,要么设置llm(planner和solver公用一个LLM),要么设置plan_llm和solve_llm,或者只指定llm(用来设置planner)和solve_llm,其它情况均认为是无效的。
    tools (List[str]): LLM使用的工具名称列表。
    plan_llm (ModuleBase): planner要使用的LLM,可以是TrainableModule或OnlineChatModule。
    solve_llm (ModuleBase): solver要使用的LLM,可以是TrainableModule或OnlineChatModule。
    max_retries (int): 工具调用迭代的最大次数。默认值为5。
    return_trace (bool): 是否返回中间步骤和工具调用信息。
    stream (bool): 是否以流式方式输出规划和解决过程。



Examples:
    >>> import lazyllm
    >>> import wikipedia
    >>> from lazyllm.tools import fc_register, ReWOOAgent
    >>> @fc_register("tool")
    >>> def WikipediaWorker(input: str):
    ...     '''
    ...     Worker that search for similar page contents from Wikipedia. Useful when you need to get holistic knowledge about people, places, companies, historical events, or other subjects. The response are long and might contain some irrelevant information. Input should be a search query.
    ...
    ...     Args:
    ...         input (str): search query.
    ...     '''
    ...     try:
    ...         evidence = wikipedia.page(input).content
    ...         evidence = evidence.split("\\n\\n")[0]
    ...     except wikipedia.PageError:
    ...         evidence = f"Could not find [{input}]. Similar: {wikipedia.search(input)}"
    ...     except wikipedia.DisambiguationError:
    ...         evidence = f"Could not find [{input}]. Similar: {wikipedia.search(input)}"
    ...     return evidence
    ...
    >>> @fc_register("tool")
    >>> def LLMWorker(input: str):
    ...     '''
    ...     A pretrained LLM like yourself. Useful when you need to act with general world knowledge and common sense. Prioritize it when you are confident in solving the problem yourself. Input can be any instruction.
    ...
    ...     Args:
    ...         input (str): instruction
    ...     '''
    ...     llm = lazyllm.OnlineChatModule(source="glm")
    ...     query = f"Respond in short directly with no extra words.\\n\\n{input}"
    ...     response = llm(query, llm_chat_history=[])
    ...     return response
    ...
    >>> tools = ["WikipediaWorker", "LLMWorker"]
    >>> llm = lazyllm.TrainableModule("GLM-4-9B-Chat").deploy_method(lazyllm.deploy.vllm).start()  # or llm = lazyllm.OnlineChatModule(source="sensenova")
    >>> agent = ReWOOAgent(llm, tools)
    >>> query = "What is the name of the cognac house that makes the main ingredient in The Hennchata?"
    >>> res = agent(query)
    >>> print(res)
    '
    Hennessy '
    """
    def __init__(self, llm: Union[ModuleBase, None] = None, tools: List[Union[str, Callable]] = [], *,  # noqa B006
                 plan_llm: Union[ModuleBase, None] = None, solve_llm: Union[ModuleBase, None] = None,
                 return_trace: bool = False, stream: bool = False):
        super().__init__(return_trace=return_trace)
        assert (llm is None and plan_llm and solve_llm) or (llm and plan_llm is None), 'Either specify only llm \
               without specify plan and solve, or specify only plan and solve without specifying llm, or specify \
               both llm and solve. Other situations are not allowed.'
        assert tools, "tools cannot be empty."
        self._planner = (plan_llm or llm).share(stream=dict(
            prefix='\nI will give a plan first:\n', prefix_color=Color.blue, color=Color.green) if stream else False)
        self._solver = (solve_llm or llm).share(stream=dict(
            prefix='\nI will solve the problem:\n', prefix_color=Color.blue, color=Color.green) if stream else False)
        self._name2tool = ToolManager(tools, return_trace=return_trace).tools_info
        with pipeline() as self._agent:
            self._agent.planner_pre_action = self._build_planner_prompt
            self._agent.planner = self._planner
            self._agent.parse_plan = self._parse_plan
            self._agent.woker = self._get_worker_evidences
            self._agent.solver_pre_action = self._build_solver_prompt | bind(input=self._agent.input)
            self._agent.solver = self._solver

    def _build_planner_prompt(self, input: str):
        prompt = P_PROMPT_PREFIX + "Tools can be one of the following:\n"
        for name, tool in self._name2tool.items():
            prompt += f"{name}[search query]: {tool.description}\n"
        prompt += P_FEWSHOT + "\n" + P_PROMPT_SUFFIX + input + "\n"
        globals['chat_history'][self._planner._module_id] = []
        return prompt

    def _parse_plan(self, response: str):
        LOG.debug(f"planner plans: {response}")
        plans = []
        evidence = {}
        for line in response.splitlines():
            if line.startswith("Plan"):
                plans.append(line)
            elif line.startswith("#") and line[1] == "E" and line[2].isdigit():
                e, tool_call = line.split("=", 1)
                e, tool_call = e.strip(), tool_call.strip()
                if len(e) == 3:
                    evidence[e] = tool_call
                else:
                    evidence[e] = "No evidence found"
        return package(plans, evidence)

    def _get_worker_evidences(self, plans: List[str], evidence: Dict[str, str]):
        worker_evidences = {}
        for e, tool_call in evidence.items():
            if "[" not in tool_call:
                worker_evidences[e] = tool_call
                continue
            tool, tool_input = tool_call.split("[", 1)
            tool_input = tool_input[:-1].strip("'").strip('"')
            # find variables in input and replace with previous evidences
            for var in re.findall(r"#E\d+", tool_input):
                if var in worker_evidences:
                    tool_input = tool_input.replace(var, "[" + worker_evidences[var] + "]")
            tool_instance = self._name2tool.get(tool)
            if tool_instance:
                worker_evidences[e] = tool_instance(tool_input)
            else:
                worker_evidences[e] = "No evidence found"

        worker_log = ""
        for idx, plan in enumerate(plans):
            e = f"#E{idx+1}"
            worker_log += f"{plan}\nEvidence:\n{worker_evidences[e]}\n"
        LOG.debug(f"worker_log: {worker_log}")
        return worker_log

    def _build_solver_prompt(self, worker_log, input):
        prompt = S_PROMPT_PREFIX + input + "\n" + worker_log + S_PROMPT_SUFFIX + input + "\n"
        globals['chat_history'][self._solver._module_id] = []
        return prompt

    def forward(self, query: str):
        return self._agent(query)

lazyllm.tools.rag.smart_embedding_index.SmartEmbeddingIndex

Bases: IndexBase

Source code in lazyllm/tools/rag/smart_embedding_index.py
class SmartEmbeddingIndex(IndexBase):
    def __init__(self, backend_type: str, **kwargs):
        if backend_type == 'milvus':
            self._store = MilvusStore(**kwargs)
        elif backend_type == 'map':
            self._store = MapStore(**kwargs)
        else:
            raise ValueError(f'unsupported backend [{backend_type}]')

    @override
    def update(self, nodes: List[DocNode]) -> None:
        self._store.update_nodes(nodes)

    @override
    def remove(self, uids: List[str], group_name: Optional[str] = None) -> None:
        self._store.remove_nodes(uids=uids)

    @override
    def query(self, *args, **kwargs) -> List[DocNode]:
        return self._store.query(*args, **kwargs)

lazyllm.tools.rag.doc_node.ImageDocNode

Bases: DocNode

专门用于处理RAG系统中图像内容的文档节点。

ImageDocNode继承自DocNode,为图像处理和嵌入生成提供专门的功能。它自动处理图像加载、用于嵌入的base64编码,以及用于LLM处理的PIL图像对象。

Parameters:

  • image_path (str) –

    图像文件的文件路径。这应该是一个有效的图像文件路径(例如.jpg、.png、.jpeg)。

  • uid (Optional[str], default: None ) –

    文档节点的唯一标识符。如果未提供,将自动生成UUID。

  • group (Optional[str], default: None ) –

    此节点所属的组名。用于组织和过滤节点。

  • embedding (Optional[Dict[str, List[float]]], default: None ) –

    图像的预计算嵌入。键是嵌入模型名称,值是嵌入向量。

  • parent (Optional[DocNode], default: None ) –

    文档层次结构中的父节点。用于构建文档树。

  • metadata (Optional[Dict[str, Any]], default: None ) –

    与图像节点关联的附加元数据。

  • global_metadata (Optional[Dict[str, Any]], default: None ) –

    适用于文档中所有节点的全局元数据。

  • text (Optional[str], default: None ) –

    图像的可选文本描述或标题。

Examples:

>>> from lazyllm.tools.rag.doc_node import ImageDocNode, MetadataMode
>>> import numpy as np
>>> image_node = ImageDocNode(
...     image_path="/home/mnt/yehongfei/Code/Test/framework.jpg",
...     text="这是一张照片"
)
>>> def clip_emb(content, modality="image"):
...     if modality == "image":
...         return [np.random.rand(512).tolist()]
...     return [np.random.rand(256).tolist()]
>>> embed_functions = {"clip": clip_emb}
>>> image_node.do_embedding(embed_functions)
>>> print(f"嵌入维度: {len(image_node.embedding['clip'])}")
>>> text_representation = image_node.get_text()
>>> content_representation = image_node.get_content(MetadataMode.EMBED)
>>> print(f"text属性: {text_representation}")
>>> print(f"content属性: {content_representation}")
Source code in lazyllm/tools/rag/doc_node.py
class ImageDocNode(DocNode):
    """专门用于处理RAG系统中图像内容的文档节点。

ImageDocNode继承自DocNode,为图像处理和嵌入生成提供专门的功能。它自动处理图像加载、用于嵌入的base64编码,以及用于LLM处理的PIL图像对象。

Args:
    image_path (str): 图像文件的文件路径。这应该是一个有效的图像文件路径(例如.jpg、.png、.jpeg)。
    uid (Optional[str]): 文档节点的唯一标识符。如果未提供,将自动生成UUID。
    group (Optional[str]): 此节点所属的组名。用于组织和过滤节点。
    embedding (Optional[Dict[str, List[float]]]): 图像的预计算嵌入。键是嵌入模型名称,值是嵌入向量。
    parent (Optional[DocNode]): 文档层次结构中的父节点。用于构建文档树。
    metadata (Optional[Dict[str, Any]]): 与图像节点关联的附加元数据。
    global_metadata (Optional[Dict[str, Any]]): 适用于文档中所有节点的全局元数据。
    text (Optional[str]): 图像的可选文本描述或标题。


Examples:
    >>> from lazyllm.tools.rag.doc_node import ImageDocNode, MetadataMode
    >>> import numpy as np
    >>> image_node = ImageDocNode(
    ...     image_path="/home/mnt/yehongfei/Code/Test/framework.jpg",
    ...     text="这是一张照片"
    )
    >>> def clip_emb(content, modality="image"):
    ...     if modality == "image":
    ...         return [np.random.rand(512).tolist()]
    ...     return [np.random.rand(256).tolist()]
    >>> embed_functions = {"clip": clip_emb}
    >>> image_node.do_embedding(embed_functions)
    >>> print(f"嵌入维度: {len(image_node.embedding['clip'])}")
    >>> text_representation = image_node.get_text()
    >>> content_representation = image_node.get_content(MetadataMode.EMBED)
    >>> print(f"text属性: {text_representation}")
    >>> print(f"content属性: {content_representation}")    
    """
    def __init__(self, image_path: str, uid: Optional[str] = None, group: Optional[str] = None,
                 embedding: Optional[Dict[str, List[float]]] = None, parent: Optional["DocNode"] = None,
                 metadata: Optional[Dict[str, Any]] = None, global_metadata: Optional[Dict[str, Any]] = None,
                 *, text: Optional[str] = None):
        super().__init__(uid, None, group, embedding, parent, metadata, global_metadata=global_metadata, text=text)
        self._image_path = image_path.strip()
        self._modality = 'image'

    def do_embedding(self, embed: Dict[str, Callable]) -> None:
        """使用提供的嵌入函数为图像生成嵌入。

此方法重写父类方法以处理图像特定的嵌入生成。它自动将图像转换为适当的格式(用于嵌入的base64),并使用图像模态调用嵌入函数。

Args:
    embed (Dict[str, Callable]): 嵌入函数字典。键是嵌入模型名称,值是接受(content, modality)并返回嵌入向量的可调用函数。
"""
        for k, e in embed.items():
            emb = e(self.get_content(MetadataMode.EMBED), modality=self._modality)
            generate_embed = {k: emb[0]}

        with self._lock:
            self.embedding = self.embedding or {}
            self.embedding = {**self.embedding, **generate_embed}

    def get_content(self, metadata_mode=MetadataMode.LLM) -> str:
        """根据元数据模式获取不同格式的图像内容。

此方法根据预期用例返回不同格式的图像内容。对于LLM处理,它返回PIL图像对象。对于嵌入生成,它返回base64编码的图像字符串。

Args:
    metadata_mode (MetadataMode, optional): 内容检索模式。默认为MetadataMode.LLM。
        - MetadataMode.LLM: 返回用于LLM处理的PIL图像对象
        - MetadataMode.EMBED: 返回用于嵌入生成的base64编码图像
        - 其他模式: 返回图像路径作为文本

**Returns:**

- Union[PIL.Image.Image, List[str], str]: 请求格式的图像内容。
"""
        if metadata_mode == MetadataMode.LLM:
            return Image.open(self._image_path)
        elif metadata_mode == MetadataMode.EMBED:
            image_base64, mime = _image_to_base64(self._image_path)
            return [f"data:{mime};base64,{image_base64}"]
        else:
            return self.get_text()

    @property
    def image_path(self):
        return self._image_path

    def get_text(self) -> str:  # Disable access to self._content
        """获取图像路径作为文本表示。

此方法重写父类方法以返回图像路径而不是内容字段,因为ImageDocNode不使用内容字段存储文本。

**Returns:**

- str: 图像文件路径。
"""
        return self._image_path

    @property
    def text(self) -> str:  # Disable access to self._content
        return self._image_path

do_embedding(embed)

使用提供的嵌入函数为图像生成嵌入。

此方法重写父类方法以处理图像特定的嵌入生成。它自动将图像转换为适当的格式(用于嵌入的base64),并使用图像模态调用嵌入函数。

Parameters:

  • embed (Dict[str, Callable]) –

    嵌入函数字典。键是嵌入模型名称,值是接受(content, modality)并返回嵌入向量的可调用函数。

Source code in lazyllm/tools/rag/doc_node.py
    def do_embedding(self, embed: Dict[str, Callable]) -> None:
        """使用提供的嵌入函数为图像生成嵌入。

此方法重写父类方法以处理图像特定的嵌入生成。它自动将图像转换为适当的格式(用于嵌入的base64),并使用图像模态调用嵌入函数。

Args:
    embed (Dict[str, Callable]): 嵌入函数字典。键是嵌入模型名称,值是接受(content, modality)并返回嵌入向量的可调用函数。
"""
        for k, e in embed.items():
            emb = e(self.get_content(MetadataMode.EMBED), modality=self._modality)
            generate_embed = {k: emb[0]}

        with self._lock:
            self.embedding = self.embedding or {}
            self.embedding = {**self.embedding, **generate_embed}

get_content(metadata_mode=MetadataMode.LLM)

根据元数据模式获取不同格式的图像内容。

此方法根据预期用例返回不同格式的图像内容。对于LLM处理,它返回PIL图像对象。对于嵌入生成,它返回base64编码的图像字符串。

Parameters:

  • metadata_mode (MetadataMode, default: LLM ) –

    内容检索模式。默认为MetadataMode.LLM。 - MetadataMode.LLM: 返回用于LLM处理的PIL图像对象 - MetadataMode.EMBED: 返回用于嵌入生成的base64编码图像 - 其他模式: 返回图像路径作为文本

Returns:

  • Union[PIL.Image.Image, List[str], str]: 请求格式的图像内容。
Source code in lazyllm/tools/rag/doc_node.py
    def get_content(self, metadata_mode=MetadataMode.LLM) -> str:
        """根据元数据模式获取不同格式的图像内容。

此方法根据预期用例返回不同格式的图像内容。对于LLM处理,它返回PIL图像对象。对于嵌入生成,它返回base64编码的图像字符串。

Args:
    metadata_mode (MetadataMode, optional): 内容检索模式。默认为MetadataMode.LLM。
        - MetadataMode.LLM: 返回用于LLM处理的PIL图像对象
        - MetadataMode.EMBED: 返回用于嵌入生成的base64编码图像
        - 其他模式: 返回图像路径作为文本

**Returns:**

- Union[PIL.Image.Image, List[str], str]: 请求格式的图像内容。
"""
        if metadata_mode == MetadataMode.LLM:
            return Image.open(self._image_path)
        elif metadata_mode == MetadataMode.EMBED:
            image_base64, mime = _image_to_base64(self._image_path)
            return [f"data:{mime};base64,{image_base64}"]
        else:
            return self.get_text()

get_text()

获取图像路径作为文本表示。

此方法重写父类方法以返回图像路径而不是内容字段,因为ImageDocNode不使用内容字段存储文本。

Returns:

  • str: 图像文件路径。
Source code in lazyllm/tools/rag/doc_node.py
    def get_text(self) -> str:  # Disable access to self._content
        """获取图像路径作为文本表示。

此方法重写父类方法以返回图像路径而不是内容字段,因为ImageDocNode不使用内容字段存储文本。

**Returns:**

- str: 图像文件路径。
"""
        return self._image_path

lazyllm.tools.rag.transform.AdaptiveTransform

Bases: NodeTransform

一个灵活的文档转换系统,根据文档模式应用不同的转换策略。

AdaptiveTransform允许您定义多种转换策略,并根据文档的文件路径或自定义模式匹配自动选择适当的转换方法。当您有不同类型的文档需要不同处理方法时,这特别有用。

Parameters:

  • transforms (Union[List[Union[TransformArgs, Dict]], Union[TransformArgs, Dict]]) –

    转换配置列表或单个转换配置。

  • num_workers (int, default: 0 ) –

    并行处理的工作线程数。默认为0。

Examples:

>>> from lazyllm.tools.rag.transform import AdaptiveTransform, DocNode, SentenceSplitter
>>> doc1 = DocNode(text="这是第一个文档的内容。它包含多个句子。")
>>> doc2 = DocNode(text="这是第二个文档的内容。")
>>> transforms = [
...     {
...         'f': SentenceSplitter,
...         'pattern': '*.txt',
...         'kwargs': {'chunk_size': 50, 'chunk_overlap': 10}
...     },
...     {
...         'f': SentenceSplitter,
...         'pattern': '*.pdf',
...         'kwargs': {'chunk_size': 100, 'chunk_overlap': 20}
...     }
... ]
>>> adaptive = AdaptiveTransform(transforms)
>>> results1 = adaptive.transform(doc1)
>>> print(f"文档1转换结果: {len(results1)} 个块")
>>> for i, result in enumerate(results1):
...     print(f"  块 {i+1}: {result.text}")
>>> results2 = adaptive.transform(doc2)
>>> print(f"文档2转换结果: {len(results2)} 个块")
>>> for i, result in enumerate(results2):
...     print(f"  块 {i+1}: {result.text}")
Source code in lazyllm/tools/rag/transform.py
class AdaptiveTransform(NodeTransform):
    """一个灵活的文档转换系统,根据文档模式应用不同的转换策略。

AdaptiveTransform允许您定义多种转换策略,并根据文档的文件路径或自定义模式匹配自动选择适当的转换方法。当您有不同类型的文档需要不同处理方法时,这特别有用。

Args:
    transforms (Union[List[Union[TransformArgs, Dict]], Union[TransformArgs, Dict]]): 转换配置列表或单个转换配置。
    num_workers (int, optional): 并行处理的工作线程数。默认为0。


Examples:
    >>> from lazyllm.tools.rag.transform import AdaptiveTransform, DocNode, SentenceSplitter
    >>> doc1 = DocNode(text="这是第一个文档的内容。它包含多个句子。")
    >>> doc2 = DocNode(text="这是第二个文档的内容。")
    >>> transforms = [
    ...     {
    ...         'f': SentenceSplitter,
    ...         'pattern': '*.txt',
    ...         'kwargs': {'chunk_size': 50, 'chunk_overlap': 10}
    ...     },
    ...     {
    ...         'f': SentenceSplitter,
    ...         'pattern': '*.pdf',
    ...         'kwargs': {'chunk_size': 100, 'chunk_overlap': 20}
    ...     }
    ... ]
    >>> adaptive = AdaptiveTransform(transforms)
    >>> results1 = adaptive.transform(doc1)
    >>> print(f"文档1转换结果: {len(results1)} 个块")
    >>> for i, result in enumerate(results1):
    ...     print(f"  块 {i+1}: {result.text}")
    >>> results2 = adaptive.transform(doc2)
    >>> print(f"文档2转换结果: {len(results2)} 个块")
    >>> for i, result in enumerate(results2):
    ...     print(f"  块 {i+1}: {result.text}")      
    """
    def __init__(self, transforms: Union[List[Union[TransformArgs, Dict]], Union[TransformArgs, Dict]],
                 num_workers: int = 0):
        super().__init__(num_workers=num_workers)
        if not isinstance(transforms, (tuple, list)): transforms = [transforms]
        self._transformers = [(t.get('pattern'), make_transform(t)) for t in transforms]

    def transform(self, document: DocNode, **kwargs) -> List[Union[str, DocNode]]:
        """根据模式匹配使用适当的转换策略转换文档。

此方法按顺序评估每个转换配置,并应用第一个匹配文档路径模式的转换。匹配逻辑支持glob模式和自定义可调用函数。

Args:
    document (DocNode): 要转换的文档节点。
    **kwargs: 传递给转换函数的附加关键字参数。

**Returns:**

- List[Union[str, DocNode]]: 转换结果列表(字符串或DocNode对象)。
"""
        if not isinstance(document, DocNode): LOG.warning(f'Invalud document type {type(document)} got')
        for pt, transform in self._transformers:
            if pt and isinstance(pt, str) and not pt.startswith('*'): pt = os.path.join(str(os.cwd()), pt)
            if not pt or (callable(pt) and pt(document.docpath)) or (
                    isinstance(pt, str) and fnmatch.fnmatch(document.docpath, pt)):
                return transform(document, **kwargs)
        LOG.warning(f'No transform found for document {document.docpath} with group name `{self._name}`')
        return []

transform(document, **kwargs)

根据模式匹配使用适当的转换策略转换文档。

此方法按顺序评估每个转换配置,并应用第一个匹配文档路径模式的转换。匹配逻辑支持glob模式和自定义可调用函数。

Parameters:

  • document (DocNode) –

    要转换的文档节点。

  • **kwargs

    传递给转换函数的附加关键字参数。

Returns:

  • List[Union[str, DocNode]]: 转换结果列表(字符串或DocNode对象)。
Source code in lazyllm/tools/rag/transform.py
    def transform(self, document: DocNode, **kwargs) -> List[Union[str, DocNode]]:
        """根据模式匹配使用适当的转换策略转换文档。

此方法按顺序评估每个转换配置,并应用第一个匹配文档路径模式的转换。匹配逻辑支持glob模式和自定义可调用函数。

Args:
    document (DocNode): 要转换的文档节点。
    **kwargs: 传递给转换函数的附加关键字参数。

**Returns:**

- List[Union[str, DocNode]]: 转换结果列表(字符串或DocNode对象)。
"""
        if not isinstance(document, DocNode): LOG.warning(f'Invalud document type {type(document)} got')
        for pt, transform in self._transformers:
            if pt and isinstance(pt, str) and not pt.startswith('*'): pt = os.path.join(str(os.cwd()), pt)
            if not pt or (callable(pt) and pt(document.docpath)) or (
                    isinstance(pt, str) and fnmatch.fnmatch(document.docpath, pt)):
                return transform(document, **kwargs)
        LOG.warning(f'No transform found for document {document.docpath} with group name `{self._name}`')
        return []

lazyllm.tools.rag.rerank.ModuleReranker

Bases: Reranker

使用可训练模块根据查询相关性重新排序文档的重排序器。

ModuleReranker是一个专门的重排序器,利用可训练模型(如BGE-reranker、Cohere rerank等)来提高检索文档的相关性。它接收文档列表和查询,然后返回按相关性分数重新排序的文档。

Parameters:

  • name (str, default: 'ModuleReranker' ) –

    重排序器的名称。默认为"ModuleReranker"。

  • model (Union[Callable, str], default: None ) –

    重排序模型。可以是模型名称(字符串)或可调用函数。

  • target (Optional[str], default: None ) –

    默认为None。

  • output_format (Optional[str], default: None ) –

    输出处理格式。默认为None。

  • join (Union[bool, str], default: False ) –

    是否连接结果。默认为False。

  • **kwargs

    传递给重排序模型模型的附加关键字参数。

Examples:

>>> from lazyllm.tools.rag.rerank import ModuleReranker, DocNode
>>> def simple_reranker(query, documents, top_n):
...     query_lower = query.lower()
...     scores = []
...     for i, doc in enumerate(documents):
...         score = sum(1 for word in query_lower.split() if word in doc)
...         scores.append((i, score))
...     scores.sort(key=lambda x: x[1], reverse=True)
...     return scores[:top_n]
>>> reranker = ModuleReranker(
...     model=simple_reranker,
...     topk=2
... )
>>> docs = [
...     DocNode(text="机器学习算法在数据分析中应用广泛"),
...     DocNode(text="深度学习模型需要大量训练数据"),
...     DocNode(text="自然语言处理技术发展迅速"),
...     DocNode(text="计算机视觉在自动驾驶中的应用")
... ]
>>> query = "机器学习"
>>> results = reranker.forward(docs, query)
>>> for i, doc in enumerate(results):
...     print(f"  {i+1}. : {doc.text}")
...     print(f"     相关性分数: {doc.relevance_score:.4f}")
Source code in lazyllm/tools/rag/rerank.py
@Reranker.register_reranker()
class ModuleReranker(Reranker):
    """使用可训练模块根据查询相关性重新排序文档的重排序器。

ModuleReranker是一个专门的重排序器,利用可训练模型(如BGE-reranker、Cohere rerank等)来提高检索文档的相关性。它接收文档列表和查询,然后返回按相关性分数重新排序的文档。

Args:
    name (str): 重排序器的名称。默认为"ModuleReranker"。
    model (Union[Callable, str]): 重排序模型。可以是模型名称(字符串)或可调用函数。
    target (Optional[str]): 默认为None。
    output_format (Optional[str]): 输出处理格式。默认为None。
    join (Union[bool, str]): 是否连接结果。默认为False。
    **kwargs: 传递给重排序模型模型的附加关键字参数。


Examples:
    >>> from lazyllm.tools.rag.rerank import ModuleReranker, DocNode
    >>> def simple_reranker(query, documents, top_n):
    ...     query_lower = query.lower()
    ...     scores = []
    ...     for i, doc in enumerate(documents):
    ...         score = sum(1 for word in query_lower.split() if word in doc)
    ...         scores.append((i, score))
    ...     scores.sort(key=lambda x: x[1], reverse=True)
    ...     return scores[:top_n]
    >>> reranker = ModuleReranker(
    ...     model=simple_reranker,
    ...     topk=2
    ... )
    >>> docs = [
    ...     DocNode(text="机器学习算法在数据分析中应用广泛"),
    ...     DocNode(text="深度学习模型需要大量训练数据"),
    ...     DocNode(text="自然语言处理技术发展迅速"),
    ...     DocNode(text="计算机视觉在自动驾驶中的应用")
    ... ]
    >>> query = "机器学习"
    >>> results = reranker.forward(docs, query)
    >>> for i, doc in enumerate(results):
    ...     print(f"  {i+1}. : {doc.text}")
    ...     print(f"     相关性分数: {doc.relevance_score:.4f}")        
    """

    def __init__(self, name: str = "ModuleReranker", model: Union[Callable, str] = None, target: Optional[str] = None,
                 output_format: Optional[str] = None, join: Union[bool, str] = False, **kwargs) -> None:
        super().__init__(name, target, output_format, join, **kwargs)
        assert model is not None, "Reranker model must be specified as a model name or a callable."
        if isinstance(model, str):
            self._reranker = lazyllm.TrainableModule(model)
        else:
            self._reranker = model

    def forward(self, nodes: List[DocNode], query: str = "") -> List[DocNode]:
        """重排序器的前向传播,根据与查询的相关性重新排序文档。

此方法接收文档列表和查询,然后使用底层重排序模型对文档进行评分和重新排序。文档以MetadataMode.EMBED格式处理,以确保与重排序模型的兼容性。

Args:
    nodes (List[DocNode]): 要重排序的文档节点列表。
    query (str): 用于排序文档的查询字符串。默认为""。

**Returns:**

- List[DocNode]: 按相关性分数重新排序的文档节点列表,添加了relevance_score属性。
"""
        if not nodes:
            return self._post_process([])

        docs = [node.get_text(metadata_mode=MetadataMode.EMBED) for node in nodes]
        top_n = self._kwargs['topk'] if 'topk' in self._kwargs else len(docs)
        sorted_indices = self._reranker(query, documents=docs, top_n=top_n)
        results = []
        for index, relevance_score in sorted_indices:
            results.append(nodes[index].with_score(relevance_score))
        LOG.debug(f"Rerank use `{self._name}` and get nodes: {results}")
        return self._post_process(results)

forward(nodes, query='')

重排序器的前向传播,根据与查询的相关性重新排序文档。

此方法接收文档列表和查询,然后使用底层重排序模型对文档进行评分和重新排序。文档以MetadataMode.EMBED格式处理,以确保与重排序模型的兼容性。

Parameters:

  • nodes (List[DocNode]) –

    要重排序的文档节点列表。

  • query (str, default: '' ) –

    用于排序文档的查询字符串。默认为""。

Returns:

  • List[DocNode]: 按相关性分数重新排序的文档节点列表,添加了relevance_score属性。
Source code in lazyllm/tools/rag/rerank.py
    def forward(self, nodes: List[DocNode], query: str = "") -> List[DocNode]:
        """重排序器的前向传播,根据与查询的相关性重新排序文档。

此方法接收文档列表和查询,然后使用底层重排序模型对文档进行评分和重新排序。文档以MetadataMode.EMBED格式处理,以确保与重排序模型的兼容性。

Args:
    nodes (List[DocNode]): 要重排序的文档节点列表。
    query (str): 用于排序文档的查询字符串。默认为""。

**Returns:**

- List[DocNode]: 按相关性分数重新排序的文档节点列表,添加了relevance_score属性。
"""
        if not nodes:
            return self._post_process([])

        docs = [node.get_text(metadata_mode=MetadataMode.EMBED) for node in nodes]
        top_n = self._kwargs['topk'] if 'topk' in self._kwargs else len(docs)
        sorted_indices = self._reranker(query, documents=docs, top_n=top_n)
        results = []
        for index, relevance_score in sorted_indices:
            results.append(nodes[index].with_score(relevance_score))
        LOG.debug(f"Rerank use `{self._name}` and get nodes: {results}")
        return self._post_process(results)

lazyllm.tools.rag.utils.DocListManager

Bases: ABC

抽象基类,用于管理文档列表和监控文档目录变化。

Parameters:

  • path

    要监控的文档目录路径。

  • name

    管理器名称。

  • enable_path_monitoring

    启用路径监控。

Examples:

>>> import lazyllm
>>> from lazyllm.rag.utils import DocListManager
>>> manager = DocListManager(path='your_file_path/', name="test_manager", enable_path_monitoring=False)
>>> added_docs = manager.add_files([test_file_list])
>>> manager.enable_path_monitoring(True)
>>> deleted = manager.delete_files([delete_file_list])
Source code in lazyllm/tools/rag/utils.py
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
class DocListManager(ABC):
    """抽象基类,用于管理文档列表和监控文档目录变化。

Args:
    path:要监控的文档目录路径。
    name:管理器名称。
    enable_path_monitoring:启用路径监控。



Examples:

    >>> import lazyllm
    >>> from lazyllm.rag.utils import DocListManager
    >>> manager = DocListManager(path='your_file_path/', name="test_manager", enable_path_monitoring=False)
    >>> added_docs = manager.add_files([test_file_list])
    >>> manager.enable_path_monitoring(True)
    >>> deleted = manager.delete_files([delete_file_list])
    """
    DEFAULT_GROUP_NAME = '__default__'
    __pool__ = dict()

    class Status:
        all = 'all'
        waiting = 'waiting'
        working = 'working'
        success = 'success'
        failed = 'failed'
        deleting = 'deleting'
        # deleted is no longer used
        deleted = 'deleted'

    def __init__(self, path, name, enable_path_monitoring=True):
        self._path = path
        self._name = name
        lazyllm.LOG.info(f'DocManager use file-system monitoring worker: {enable_path_monitoring}')
        self._id = hashlib.sha256(f'{name}@+@{path}'.encode()).hexdigest()
        if not os.path.isabs(path):
            raise ValueError(f"path [{path}] is not an absolute path")

        self._init_sql()
        self._delete_nonexistent_docs_on_startup()

        self._monitor_thread = threading.Thread(target=self._monitor_directory_worker)
        self._monitor_thread.daemon = True
        self._monitor_continue = True
        self._enable_path_monitoring = enable_path_monitoring
        self._init_monitor_event = threading.Event()
        if self._enable_path_monitoring:
            self._monitor_thread.start()
            self._init_monitor_event.wait()

    def _delete_nonexistent_docs_on_startup(self):
        ids = [row[0] for row in self.list_kb_group_files(details=True)
               if not Path(row[1]).exists()]
        if ids: self.delete_files(ids)

    def __new__(cls, *args, **kw):
        if cls is not DocListManager:
            return super().__new__(cls)
        return super().__new__(__class__.__pool__[config['default_dlmanager']])

    def init_tables(self) -> 'DocListManager':
        """确保数据库表默认分组存在。
"""
        if not self.table_inited():
            self._init_tables()
        # in case of using after relase
        self.add_kb_group(DocListManager.DEFAULT_GROUP_NAME)
        return self

    def _monitor_directory(self) -> Set[str]:
        files_list = []
        for root, _, files in os.walk(self._path):
            files = [os.path.join(root, file_path) for file_path in files]
            files_list.extend(files)
        return set(files_list)

    # Actually it shoule be "set_docs_status_deleting"
    def delete_files(self, file_ids: List[str]) -> List[DocPartRow]:
        """将与文件关联的知识库条目设为删除中,并由各知识库进行异步删除解析结果及关联记录。

Args:
    file_ids (list of str): 要删除的文件ID列表
"""
        document_list = self.update_file_status(file_ids, DocListManager.Status.deleting)
        self.update_kb_group(cond_file_ids=file_ids, new_status=DocListManager.Status.deleting)
        return document_list

    @abstractmethod
    def table_inited(self):
        """检查数据库中的 `documents` 表是否已初始化。此方法在访问数据库时确保线程安全。
判断数据库中是否存在 `documents` 表。
返回值:
    bool: 如果 `documents` 表存在,返回 `True`;否则返回 `False`。
说明:
    - 使用线程安全锁 (`self._db_lock`) 确保对数据库的安全访问。
    - 通过 `self._db_path` 连接 SQLite 数据库,并使用 `check_same_thread` 配置选项。
    - 执行 SQL 查询:`SELECT name FROM sqlite_master WHERE type='table' AND name='documents'` 来检查表是否存在。
"""
        pass

    @abstractmethod
    def _init_tables(self): pass

    @abstractmethod
    def validate_paths(self, paths: List[str]) -> Tuple[bool, str, List[bool]]:
        """验证一组文件路径,以确保它们可以被正常处理。
此方法检查提供的路径是否是新的、已处理的或当前正在处理的,并确保处理文档时不会发生冲突。
参数:
    paths (List[str]): 要验证的文件路径列表。
返回值:
    Tuple[bool, str, List[bool]]: 返回一个元组,包括:
        - `bool`: 如果所有路径有效,则返回 `True`;否则返回 `False`。
        - `str`: 表示成功或失败原因的消息。
        - `List[bool]`: 一个布尔值列表,每个元素对应一个路径是否为新路径(`True` 表示新路径,`False` 表示已存在)。
说明:
    - 如果任何文档仍在处理中或需要重新解析,该方法会返回 `False`,并附带相应的错误消息。
    - 方法通过数据库会话和线程安全锁 (`self._db_lock`) 检索文档状态信息。
    - 不安全状态包括 `working` 和 `waiting`。

"""
        pass

    @abstractmethod
    def update_need_reparsing(self, doc_id: str, need_reparse: bool):
        """更新 `KBGroupDocuments` 表中某个文档的 `need_reparse` 状态。
此方法设置指定文档的 `need_reparse` 标志,并可选限定到特定分组。
参数:
    doc_id (str): 要更新的文档ID。
    need_reparse (bool): `need_reparse` 标志的新值。
    group_name (Optional[str]): 如果提供,仅对指定分组应用更新;如果未提供,则对包含该文档的所有分组应用更新。
说明:
    - 使用线程安全锁 (`self._db_lock`) 确保数据库访问安全。
    - `group_name` 参数允许将更新限定到特定分组;如果未提供,则更新应用于包含该文档的所有分组。
    - 方法会立刻将更改提交到数据库。
"""
        pass

    @abstractmethod
    def list_files(self, limit: Optional[int] = None, details: bool = False,
                   status: Union[str, List[str]] = Status.all,
                   exclude_status: Optional[Union[str, List[str]]] = None):
        """从 `documents` 表中列出文件,并支持过滤、限制返回结果以及返回详细信息。
此方法根据指定的条件,从数据库中检索文件ID或详细文件信息。
参数:
    limit (Optional[int]): 返回的最大文件数量。如果为 `None`,则返回所有匹配的文件。
    details (bool): 是否返回详细的文件信息(`True`)或仅返回文件ID(`False`)。
    status (Union[str, List[str]]): 要包含的状态或状态列表,默认为所有状态。
    exclude_status (Optional[Union[str, List[str]]]): 要排除的状态或状态列表,默认为 `None`。
返回值:
    List: 如果 `details=False`,则返回文件ID列表;如果 `details=True`,则返回详细文件行的列表。
说明:
    - 该方法根据 `status` 和 `exclude_status` 条件动态构造查询。
    - 使用线程安全锁 (`self._db_lock`) 确保数据库访问安全。
    - 如果指定了 `limit`,查询会附加 `LIMIT` 子句。
"""
        pass

    @abstractmethod
    def get_docs(self, doc_ids: List[str]) -> List[KBDocument]:
        """从数据库中检索类型为 `KBDocument` 的文档对象,基于提供的文档 ID 列表。

Args:
    doc_ids (List[str]): 要获取的文档 ID 列表。
**Returns:**
    List[KBDocument]: 与提供的文档 ID 对应的 `KBDocument` 对象列表。如果没有找到文档,将返回空列表。
说明:
    - 使用线程安全锁 (`self._db_lock`) 确保数据库访问的安全性。
    - 查询使用 SQL 的 `IN` 子句,通过 `doc_id` 字段进行过滤。
    - 如果 `doc_ids` 为空,函数将直接返回空列表,而不会查询数据库。
"""
        pass

    @abstractmethod
    def set_docs_new_meta(self, doc_meta: Dict[str, dict]):
        """批量更新文档的元数据。

Args:
    doc_meta (Dict[str, dict]): 文档ID到新元数据的映射字典。

"""
        pass

    @abstractmethod
    def fetch_docs_changed_meta(self, group: str) -> List[DocMetaChangedRow]:
        """获取指定组中元数据已更改的文档,并将其 `new_meta` 字段重置为 `None`。
此方法检索元数据已更改(即 `new_meta` 不为 `None`)的所有文档,基于提供的组名。检索后,会将这些文档的 `new_meta` 字段重置为 `None`。

Args:
    group (str): 用于过滤文档的组名。
**Returns:**
    List[DocMetaChangedRow]: 包含文档 `doc_id` 和 `new_meta` 字段的行列表,表示元数据已更改的文档。
说明:
    - 使用线程安全锁 (`self._db_lock`) 确保数据库访问安全。
    - 方法通过 SQL `JOIN` 操作连接 `KBDocument` 和 `KBGroupDocuments` 表以检索相关行。
    - 在获取数据后,将受影响行的 `new_meta` 字段更新为 `None`,并将更改提交到数据库。
"""
        pass

    @abstractmethod
    def list_all_kb_group(self):
        """列出所有知识库分组的名称。

**Returns:**
- list: 知识库分组名称列表。
"""
        pass

    @abstractmethod
    def add_kb_group(self, name):
        """添加一个新的知识库分组。

Args:
    name (str): 要添加的分组名称。
"""
        pass

    @abstractmethod
    def list_kb_group_files(self, group: str = None, limit: Optional[int] = None, details: bool = False,
                            status: Union[str, List[str]] = Status.all,
                            exclude_status: Optional[Union[str, List[str]]] = None,
                            upload_status: Union[str, List[str]] = Status.all,
                            exclude_upload_status: Optional[Union[str, List[str]]] = None,
                            need_reparse: Optional[bool] = False):
        """列出指定知识库组中的文件。

Args:
    group (str): 用于过滤文件的 KB 组名。默认为 `None`。
    limit (Optional[int]): 返回的最大文件数量。如果为 `None`,则返回所有匹配的文件。
    details (bool): 返回详细的文件信息或仅返回文件 ID 和路径。
    status (Union[str, List[str]]): 包含在结果中的 KB 组状态或状态列表。默认为所有状态。
    exclude_status (Optional[Union[str, List[str]]): 从结果中排除的 KB 组状态或状态列表。默认为 `None`。
    upload_status (Union[str, List[str]]): 包含在结果中的文档上传状态或状态列表。默认为所有状态。
    exclude_upload_status (Optional[Union[str, List[str]]): 从结果中排除的文档上传状态或状态列表。默认为 `None`。
    need_reparse (Optional[bool]): 过滤需要重新解析的文件或不需要重新解析的文件。默认为 `None`。
**Returns:**:
    List: 如果 `details=False`,返回包含 `(doc_id, path)` 的元组列表。
          如果 `details=True`,返回包含附加元数据的详细行列表。
说明:
    - 方法根据提供的过滤条件动态构建 SQL 查询。
    - 使用线程安全锁 (`self._db_lock`) 确保多线程环境下的数据库访问安全。
    - 如果 `status` 或 `upload_status` 参数为列表,则会使用 SQL 的 `IN` 子句进行处理。
"""
        pass

    def add_files(
        self,
        files: List[str],
        metadatas: Optional[List[Dict[str, Any]]] = None,
        status: Optional[str] = Status.waiting,
        batch_size: int = 64,
    ) -> List[DocPartRow]:
        """批量向文档列表中添加文件,可选附加元数据、状态,并支持分批处理。
此方法将文件列表添加到数据库中,并为每个文件设置可选的元数据和初始状态。文件会以批量方式处理以提高效率。在文件添加完成后,它们会自动关联到默认的知识库 (KB) 组。
Args:
    files (List[str]): 添加的文件路径列表。
    metadatas (Optional[List[Dict[str, Any]]]): 与文件对应的元数据字典列表。默认为 `None`。
    status (Optional[str]): 添加文件的初始状态。默认为 `Status.waiting`。
    batch_size (int): 每批处理的文件数量。默认为 64。
**Returns:**:
    List[DocPartRow]: 包含已添加文件及其相关信息的 `DocPartRow` 对象列表。
说明:
    - 方法首先通过辅助函数 `_add_doc_records` 创建文档记录。
    - 文件添加后,会自动关联到默认的知识库组 (`DocListManager.DEFAULT_GROUP_NAME`)。
    - 批量处理确保在添加大量文件时具有良好的可扩展性。
"""
        documents = self._add_doc_records(files, metadatas, status, batch_size)
        if documents:
            self.add_files_to_kb_group([doc.doc_id for doc in documents], group=DocListManager.DEFAULT_GROUP_NAME)
        return documents

    @abstractmethod
    def _get_all_docs(self): pass

    @abstractmethod
    def _get_docs(self, to_be_added_doc_ids: List, to_be_deleted_doc_ids: List, filter_status_list: List): pass

    @abstractmethod
    def _add_doc_records(self, files: List[str], metadatas: Optional[List] = None,
                         status: Optional[str] = Status.waiting, batch_size: int = 64) -> List[DocPartRow]: pass

    @abstractmethod
    def delete_unreferenced_doc(self):
        """删除数据库中标记为 "删除中" 且不再被引用的文档。
此方法从数据库中删除满足以下条件的文档:
1. 文档状态为 `DocListManager.Status.deleting`。
2. 文档的引用计数 (`count`) 为 0。
"""
        pass

    @abstractmethod
    def get_docs_need_reparse(self, group: Optional[str] = None) -> List[KBDocument]:
        """获取需要重新解析 (`need_reparse=True`)的指定组中的文档。
此方法检索标记为需要重新解析 (`need_reparse=True`) 的文档,基于提供的组名。仅包含状态为 `success` 或 `failed` 的文档。
Args:
    group (str): 用于过滤文档的组名。
**Returns:**:
    List[KBDocument]: 需要重新解析的 `KBDocument` 对象列表。
说明:
    - 使用线程安全锁 (`self._db_lock`) 确保多线程环境下的数据库访问安全。
    - 查询通过 SQL `JOIN` 操作连接 `KBDocument` 和 `KBGroupDocuments` 表,并基于组名和重新解析状态进行过滤。
    - 仅状态为 `success` 或 `failed` 且 `need_reparse=True` 的文档会被检索出来。
"""
        pass

    @abstractmethod
    def get_existing_paths_by_pattern(self, file_path: str) -> List[str]:
        """根据给定的模式,检索符合条件的文档路径。
此方法从数据库中获取所有符合提供的 SQL `LIKE` 模式的文档路径。
Args:
    pattern (str): 用于过滤文档路径的 SQL `LIKE` 模式。例如,`%example%` 匹配包含单词 "example" 的路径。
**Returns:**:
    List[str]: 符合给定模式的文档路径列表。如果没有匹配的路径,则返回空列表。
说明:
    - 使用线程安全锁 (`self._db_lock`) 确保多线程环境下的数据库访问安全。
    - SQL 查询中的 `LIKE` 操作符用于对文档路径进行模式匹配。
"""
        pass

    @abstractmethod
    def update_file_message(self, fileid: str, **kw):
        """更新指定文件的消息。

Args:
    fileid (str): 文件ID。
    **kw: 需要更新的其他键值对。
"""
        pass

    @abstractmethod
    def update_file_status(self, file_ids: List[str], status: str,
                           cond_status_list: Union[None, List[str]] = None) -> List[DocPartRow]:
        """更新指定文件的状态。

Args:
    file_ids (list of str): 更新状态的文件ID列表。
    status (str): 目标状态。
    cond_status_list(Union[None, List[str]]):限制只更新处于这些状态的文档
"""
        pass

    @abstractmethod
    def add_files_to_kb_group(self, file_ids: List[str], group: str):
        """将文件添加到指定的知识库分组中。

Args:
    file_ids (list of str): 要添加的文件ID列表。
    group (str): 要添加的分组名称。
"""
        pass

    @abstractmethod
    def delete_files_from_kb_group(self, file_ids: List[str], group: str):
        """从指定的知识库分组中删除文件。

Args:
    file_ids (list of str): 要删除的文件ID列表。
    group (str): 分组名称。
"""
        pass

    @abstractmethod
    def get_file_status(self, fileid: str):
        """获取指定文件的状态。

Args:
    fileid (str): 文件ID。

**Returns:**
- str: 文件的当前状态。
"""
        pass

    @abstractmethod
    def update_kb_group(self, cond_file_ids: List[str], cond_group: Optional[str] = None,
                        cond_status_list: Optional[List[str]] = None, new_status: Optional[str] = None,
                        new_need_reparse: Optional[bool] = None) -> List[GroupDocPartRow]:
        """更新指定知识库分组中的内容。

Args:
    cond_file_ids (list of str, optional): 过滤使用的文件ID列表,默认为None。
    cond_group (str, optional): 过滤使用的知识库分组名称,默认为None。
    cond_status_list (list of str, optional): 过滤使用的状态列表,默认为None。
    new_status (str, optional): 新状态, 默认为None。
    new_need_reparse (bool, optinoal): 新的是否需重解析标志。

**Returns:**
- list: 得到更新的列表list of (doc_id, group_name)
"""
        pass

    @abstractmethod
    def release(self):
        """释放当前管理器的资源。

"""
        pass

    @property
    def enable_path_monitoring(self):
        """启用或禁用文档管理器的路径监控功能。
此方法用于启用或禁用文档管理器的路径监控功能。当启用时,会启动一个监控线程处理与路径相关的操作;当禁用时,会停止该线程并等待它终止。
Args:
    val (bool): 启用或禁用路径监控。
说明:
    - 如果 `val` 为 `True`,路径监控功能会通过将 `_monitor_continue` 设置为 `True` 并启动 `_monitor_thread` 来启用。
    - 如果 `val` 为 `False`,路径监控功能会通过将 `_monitor_continue` 设置为 `False` 并等待 `_monitor_thread` 终止来禁用。
    - 方法在管理监控线程时确保线程操作是安全的。
"""
        return self._enable_path_monitoring

    @enable_path_monitoring.setter
    def enable_path_monitoring(self, val: bool):
        """启用或禁用文档管理器的路径监控功能。
此方法用于启用或禁用文档管理器的路径监控功能。当启用时,会启动一个监控线程处理与路径相关的操作;当禁用时,会停止该线程并等待它终止。
Args:
    val (bool): 启用或禁用路径监控。
说明:
    - 如果 `val` 为 `True`,路径监控功能会通过将 `_monitor_continue` 设置为 `True` 并启动 `_monitor_thread` 来启用。
    - 如果 `val` 为 `False`,路径监控功能会通过将 `_monitor_continue` 设置为 `False` 并等待 `_monitor_thread` 终止来禁用。
    - 方法在管理监控线程时确保线程操作是安全的。
"""
        self._enable_path_monitoring = (val is True)
        if val is True:
            self._monitor_continue = True
            self._monitor_thread.start()
        else:
            self._monitor_continue = False
            if self._monitor_thread.is_alive():
                self._monitor_thread.join()

    def _monitor_directory_worker(self):
        failed_files_count = defaultdict(int)
        docs_all = self._get_all_docs()

        previous_files = set([doc.path for doc in docs_all])
        skip_files = set()
        is_first_run = True
        while self._monitor_continue:
            # 1. Scan files in the directory, find added and deleted files
            current_files = set(self._monitor_directory())
            to_be_added_files = current_files - previous_files - skip_files
            to_be_deleted_files = previous_files - current_files - skip_files
            failed_files = set()

            to_be_added_doc_ids = set([gen_docid(ele) for ele in to_be_added_files])
            to_be_deleted_doc_ids = set([gen_docid(ele) for ele in to_be_deleted_files])
            failed_doc_ids = set()
            filter_status_list = [DocListManager.Status.success,
                                  DocListManager.Status.failed, DocListManager.Status.waiting]

            docs_not_expected, docs_expected = self._get_docs(to_be_added_doc_ids,
                                                              to_be_deleted_doc_ids, filter_status_list)

            # 2. Skip new files that are already in the database
            failed_files.update([doc.path for doc in docs_not_expected])
            failed_doc_ids.update([doc.doc_id for doc in docs_not_expected])
            to_be_added_files -= failed_files
            # Actually it is add to doc with success status, then add to kb_group with waiting status
            self.add_files(list(to_be_added_files), status=DocListManager.Status.success)

            # 3. Skip deleted files that are: 1. not in the database, 2. status not success/failed/waiting
            safe_to_delete_files = set([doc.path for doc in docs_expected])
            safe_to_delete_doc_ids = set([doc.doc_id for doc in docs_expected])
            failed_doc_ids.update(to_be_deleted_doc_ids - safe_to_delete_doc_ids)
            failed_files.update(to_be_deleted_files - safe_to_delete_files)
            to_be_deleted_files = safe_to_delete_files
            to_be_deleted_doc_ids = safe_to_delete_doc_ids
            self.delete_files(list(to_be_deleted_doc_ids))

            # 4. update skip_files
            for ele in failed_files:
                failed_files_count[ele] += 1
                if failed_files_count[ele] >= 3:
                    skip_files.add(ele)
            # update previous files, while failed files will be re-processed in the next loop
            previous_files = (current_files | to_be_added_files) - to_be_deleted_files
            if is_first_run:
                self._init_monitor_event.set()
            is_first_run = False
            time.sleep(10)
        lazyllm.LOG.warning("END MONITORING")

    def __del__(self):
        self.enable_path_monitoring = False

enable_path_monitoring property writable

启用或禁用文档管理器的路径监控功能。 此方法用于启用或禁用文档管理器的路径监控功能。当启用时,会启动一个监控线程处理与路径相关的操作;当禁用时,会停止该线程并等待它终止。 Args: val (bool): 启用或禁用路径监控。 说明: - 如果 valTrue,路径监控功能会通过将 _monitor_continue 设置为 True 并启动 _monitor_thread 来启用。 - 如果 valFalse,路径监控功能会通过将 _monitor_continue 设置为 False 并等待 _monitor_thread 终止来禁用。 - 方法在管理监控线程时确保线程操作是安全的。

add_files(files, metadatas=None, status=Status.waiting, batch_size=64)

批量向文档列表中添加文件,可选附加元数据、状态,并支持分批处理。 此方法将文件列表添加到数据库中,并为每个文件设置可选的元数据和初始状态。文件会以批量方式处理以提高效率。在文件添加完成后,它们会自动关联到默认的知识库 (KB) 组。 Args: files (List[str]): 添加的文件路径列表。 metadatas (Optional[List[Dict[str, Any]]]): 与文件对应的元数据字典列表。默认为 None。 status (Optional[str]): 添加文件的初始状态。默认为 Status.waiting。 batch_size (int): 每批处理的文件数量。默认为 64。 Returns:: List[DocPartRow]: 包含已添加文件及其相关信息的 DocPartRow 对象列表。 说明: - 方法首先通过辅助函数 _add_doc_records 创建文档记录。 - 文件添加后,会自动关联到默认的知识库组 (DocListManager.DEFAULT_GROUP_NAME)。 - 批量处理确保在添加大量文件时具有良好的可扩展性。

Source code in lazyllm/tools/rag/utils.py
    def add_files(
        self,
        files: List[str],
        metadatas: Optional[List[Dict[str, Any]]] = None,
        status: Optional[str] = Status.waiting,
        batch_size: int = 64,
    ) -> List[DocPartRow]:
        """批量向文档列表中添加文件,可选附加元数据、状态,并支持分批处理。
此方法将文件列表添加到数据库中,并为每个文件设置可选的元数据和初始状态。文件会以批量方式处理以提高效率。在文件添加完成后,它们会自动关联到默认的知识库 (KB) 组。
Args:
    files (List[str]): 添加的文件路径列表。
    metadatas (Optional[List[Dict[str, Any]]]): 与文件对应的元数据字典列表。默认为 `None`。
    status (Optional[str]): 添加文件的初始状态。默认为 `Status.waiting`。
    batch_size (int): 每批处理的文件数量。默认为 64。
**Returns:**:
    List[DocPartRow]: 包含已添加文件及其相关信息的 `DocPartRow` 对象列表。
说明:
    - 方法首先通过辅助函数 `_add_doc_records` 创建文档记录。
    - 文件添加后,会自动关联到默认的知识库组 (`DocListManager.DEFAULT_GROUP_NAME`)。
    - 批量处理确保在添加大量文件时具有良好的可扩展性。
"""
        documents = self._add_doc_records(files, metadatas, status, batch_size)
        if documents:
            self.add_files_to_kb_group([doc.doc_id for doc in documents], group=DocListManager.DEFAULT_GROUP_NAME)
        return documents

add_files_to_kb_group(file_ids, group) abstractmethod

将文件添加到指定的知识库分组中。

Parameters:

  • file_ids (list of str) –

    要添加的文件ID列表。

  • group (str) –

    要添加的分组名称。

Source code in lazyllm/tools/rag/utils.py
    @abstractmethod
    def add_files_to_kb_group(self, file_ids: List[str], group: str):
        """将文件添加到指定的知识库分组中。

Args:
    file_ids (list of str): 要添加的文件ID列表。
    group (str): 要添加的分组名称。
"""
        pass

add_kb_group(name) abstractmethod

添加一个新的知识库分组。

Parameters:

  • name (str) –

    要添加的分组名称。

Source code in lazyllm/tools/rag/utils.py
    @abstractmethod
    def add_kb_group(self, name):
        """添加一个新的知识库分组。

Args:
    name (str): 要添加的分组名称。
"""
        pass

delete_files(file_ids)

将与文件关联的知识库条目设为删除中,并由各知识库进行异步删除解析结果及关联记录。

Parameters:

  • file_ids (list of str) –

    要删除的文件ID列表

Source code in lazyllm/tools/rag/utils.py
    def delete_files(self, file_ids: List[str]) -> List[DocPartRow]:
        """将与文件关联的知识库条目设为删除中,并由各知识库进行异步删除解析结果及关联记录。

Args:
    file_ids (list of str): 要删除的文件ID列表
"""
        document_list = self.update_file_status(file_ids, DocListManager.Status.deleting)
        self.update_kb_group(cond_file_ids=file_ids, new_status=DocListManager.Status.deleting)
        return document_list

delete_files_from_kb_group(file_ids, group) abstractmethod

从指定的知识库分组中删除文件。

Parameters:

  • file_ids (list of str) –

    要删除的文件ID列表。

  • group (str) –

    分组名称。

Source code in lazyllm/tools/rag/utils.py
    @abstractmethod
    def delete_files_from_kb_group(self, file_ids: List[str], group: str):
        """从指定的知识库分组中删除文件。

Args:
    file_ids (list of str): 要删除的文件ID列表。
    group (str): 分组名称。
"""
        pass

delete_unreferenced_doc() abstractmethod

删除数据库中标记为 "删除中" 且不再被引用的文档。 此方法从数据库中删除满足以下条件的文档: 1. 文档状态为 DocListManager.Status.deleting。 2. 文档的引用计数 (count) 为 0。

Source code in lazyllm/tools/rag/utils.py
    @abstractmethod
    def delete_unreferenced_doc(self):
        """删除数据库中标记为 "删除中" 且不再被引用的文档。
此方法从数据库中删除满足以下条件的文档:
1. 文档状态为 `DocListManager.Status.deleting`。
2. 文档的引用计数 (`count`) 为 0。
"""
        pass

fetch_docs_changed_meta(group) abstractmethod

获取指定组中元数据已更改的文档,并将其 new_meta 字段重置为 None。 此方法检索元数据已更改(即 new_meta 不为 None)的所有文档,基于提供的组名。检索后,会将这些文档的 new_meta 字段重置为 None

Parameters:

  • group (str) –

    用于过滤文档的组名。

Returns: List[DocMetaChangedRow]: 包含文档 doc_idnew_meta 字段的行列表,表示元数据已更改的文档。 说明: - 使用线程安全锁 (self._db_lock) 确保数据库访问安全。 - 方法通过 SQL JOIN 操作连接 KBDocumentKBGroupDocuments 表以检索相关行。 - 在获取数据后,将受影响行的 new_meta 字段更新为 None,并将更改提交到数据库。

Source code in lazyllm/tools/rag/utils.py
    @abstractmethod
    def fetch_docs_changed_meta(self, group: str) -> List[DocMetaChangedRow]:
        """获取指定组中元数据已更改的文档,并将其 `new_meta` 字段重置为 `None`。
此方法检索元数据已更改(即 `new_meta` 不为 `None`)的所有文档,基于提供的组名。检索后,会将这些文档的 `new_meta` 字段重置为 `None`。

Args:
    group (str): 用于过滤文档的组名。
**Returns:**
    List[DocMetaChangedRow]: 包含文档 `doc_id` 和 `new_meta` 字段的行列表,表示元数据已更改的文档。
说明:
    - 使用线程安全锁 (`self._db_lock`) 确保数据库访问安全。
    - 方法通过 SQL `JOIN` 操作连接 `KBDocument` 和 `KBGroupDocuments` 表以检索相关行。
    - 在获取数据后,将受影响行的 `new_meta` 字段更新为 `None`,并将更改提交到数据库。
"""
        pass

get_docs(doc_ids) abstractmethod

从数据库中检索类型为 KBDocument 的文档对象,基于提供的文档 ID 列表。

Parameters:

  • doc_ids (List[str]) –

    要获取的文档 ID 列表。

Returns: List[KBDocument]: 与提供的文档 ID 对应的 KBDocument 对象列表。如果没有找到文档,将返回空列表。 说明: - 使用线程安全锁 (self._db_lock) 确保数据库访问的安全性。 - 查询使用 SQL 的 IN 子句,通过 doc_id 字段进行过滤。 - 如果 doc_ids 为空,函数将直接返回空列表,而不会查询数据库。

Source code in lazyllm/tools/rag/utils.py
    @abstractmethod
    def get_docs(self, doc_ids: List[str]) -> List[KBDocument]:
        """从数据库中检索类型为 `KBDocument` 的文档对象,基于提供的文档 ID 列表。

Args:
    doc_ids (List[str]): 要获取的文档 ID 列表。
**Returns:**
    List[KBDocument]: 与提供的文档 ID 对应的 `KBDocument` 对象列表。如果没有找到文档,将返回空列表。
说明:
    - 使用线程安全锁 (`self._db_lock`) 确保数据库访问的安全性。
    - 查询使用 SQL 的 `IN` 子句,通过 `doc_id` 字段进行过滤。
    - 如果 `doc_ids` 为空,函数将直接返回空列表,而不会查询数据库。
"""
        pass

get_docs_need_reparse(group=None) abstractmethod

获取需要重新解析 (need_reparse=True)的指定组中的文档。 此方法检索标记为需要重新解析 (need_reparse=True) 的文档,基于提供的组名。仅包含状态为 successfailed 的文档。 Args: group (str): 用于过滤文档的组名。 Returns:: List[KBDocument]: 需要重新解析的 KBDocument 对象列表。 说明: - 使用线程安全锁 (self._db_lock) 确保多线程环境下的数据库访问安全。 - 查询通过 SQL JOIN 操作连接 KBDocumentKBGroupDocuments 表,并基于组名和重新解析状态进行过滤。 - 仅状态为 successfailedneed_reparse=True 的文档会被检索出来。

Source code in lazyllm/tools/rag/utils.py
    @abstractmethod
    def get_docs_need_reparse(self, group: Optional[str] = None) -> List[KBDocument]:
        """获取需要重新解析 (`need_reparse=True`)的指定组中的文档。
此方法检索标记为需要重新解析 (`need_reparse=True`) 的文档,基于提供的组名。仅包含状态为 `success` 或 `failed` 的文档。
Args:
    group (str): 用于过滤文档的组名。
**Returns:**:
    List[KBDocument]: 需要重新解析的 `KBDocument` 对象列表。
说明:
    - 使用线程安全锁 (`self._db_lock`) 确保多线程环境下的数据库访问安全。
    - 查询通过 SQL `JOIN` 操作连接 `KBDocument` 和 `KBGroupDocuments` 表,并基于组名和重新解析状态进行过滤。
    - 仅状态为 `success` 或 `failed` 且 `need_reparse=True` 的文档会被检索出来。
"""
        pass

get_existing_paths_by_pattern(file_path) abstractmethod

根据给定的模式,检索符合条件的文档路径。 此方法从数据库中获取所有符合提供的 SQL LIKE 模式的文档路径。 Args: pattern (str): 用于过滤文档路径的 SQL LIKE 模式。例如,%example% 匹配包含单词 "example" 的路径。 Returns:: List[str]: 符合给定模式的文档路径列表。如果没有匹配的路径,则返回空列表。 说明: - 使用线程安全锁 (self._db_lock) 确保多线程环境下的数据库访问安全。 - SQL 查询中的 LIKE 操作符用于对文档路径进行模式匹配。

Source code in lazyllm/tools/rag/utils.py
    @abstractmethod
    def get_existing_paths_by_pattern(self, file_path: str) -> List[str]:
        """根据给定的模式,检索符合条件的文档路径。
此方法从数据库中获取所有符合提供的 SQL `LIKE` 模式的文档路径。
Args:
    pattern (str): 用于过滤文档路径的 SQL `LIKE` 模式。例如,`%example%` 匹配包含单词 "example" 的路径。
**Returns:**:
    List[str]: 符合给定模式的文档路径列表。如果没有匹配的路径,则返回空列表。
说明:
    - 使用线程安全锁 (`self._db_lock`) 确保多线程环境下的数据库访问安全。
    - SQL 查询中的 `LIKE` 操作符用于对文档路径进行模式匹配。
"""
        pass

get_file_status(fileid) abstractmethod

获取指定文件的状态。

Parameters:

  • fileid (str) –

    文件ID。

Returns: - str: 文件的当前状态。

Source code in lazyllm/tools/rag/utils.py
    @abstractmethod
    def get_file_status(self, fileid: str):
        """获取指定文件的状态。

Args:
    fileid (str): 文件ID。

**Returns:**
- str: 文件的当前状态。
"""
        pass

init_tables()

确保数据库表默认分组存在。

Source code in lazyllm/tools/rag/utils.py
    def init_tables(self) -> 'DocListManager':
        """确保数据库表默认分组存在。
"""
        if not self.table_inited():
            self._init_tables()
        # in case of using after relase
        self.add_kb_group(DocListManager.DEFAULT_GROUP_NAME)
        return self

list_all_kb_group() abstractmethod

列出所有知识库分组的名称。

Returns: - list: 知识库分组名称列表。

Source code in lazyllm/tools/rag/utils.py
    @abstractmethod
    def list_all_kb_group(self):
        """列出所有知识库分组的名称。

**Returns:**
- list: 知识库分组名称列表。
"""
        pass

list_files(limit=None, details=False, status=Status.all, exclude_status=None) abstractmethod

documents 表中列出文件,并支持过滤、限制返回结果以及返回详细信息。 此方法根据指定的条件,从数据库中检索文件ID或详细文件信息。 参数: limit (Optional[int]): 返回的最大文件数量。如果为 None,则返回所有匹配的文件。 details (bool): 是否返回详细的文件信息(True)或仅返回文件ID(False)。 status (Union[str, List[str]]): 要包含的状态或状态列表,默认为所有状态。 exclude_status (Optional[Union[str, List[str]]]): 要排除的状态或状态列表,默认为 None。 返回值: List: 如果 details=False,则返回文件ID列表;如果 details=True,则返回详细文件行的列表。 说明: - 该方法根据 statusexclude_status 条件动态构造查询。 - 使用线程安全锁 (self._db_lock) 确保数据库访问安全。 - 如果指定了 limit,查询会附加 LIMIT 子句。

Source code in lazyllm/tools/rag/utils.py
    @abstractmethod
    def list_files(self, limit: Optional[int] = None, details: bool = False,
                   status: Union[str, List[str]] = Status.all,
                   exclude_status: Optional[Union[str, List[str]]] = None):
        """从 `documents` 表中列出文件,并支持过滤、限制返回结果以及返回详细信息。
此方法根据指定的条件,从数据库中检索文件ID或详细文件信息。
参数:
    limit (Optional[int]): 返回的最大文件数量。如果为 `None`,则返回所有匹配的文件。
    details (bool): 是否返回详细的文件信息(`True`)或仅返回文件ID(`False`)。
    status (Union[str, List[str]]): 要包含的状态或状态列表,默认为所有状态。
    exclude_status (Optional[Union[str, List[str]]]): 要排除的状态或状态列表,默认为 `None`。
返回值:
    List: 如果 `details=False`,则返回文件ID列表;如果 `details=True`,则返回详细文件行的列表。
说明:
    - 该方法根据 `status` 和 `exclude_status` 条件动态构造查询。
    - 使用线程安全锁 (`self._db_lock`) 确保数据库访问安全。
    - 如果指定了 `limit`,查询会附加 `LIMIT` 子句。
"""
        pass

list_kb_group_files(group=None, limit=None, details=False, status=Status.all, exclude_status=None, upload_status=Status.all, exclude_upload_status=None, need_reparse=False) abstractmethod

列出指定知识库组中的文件。

Parameters:

  • group (str, default: None ) –

    用于过滤文件的 KB 组名。默认为 None

  • limit (Optional[int], default: None ) –

    返回的最大文件数量。如果为 None,则返回所有匹配的文件。

  • details (bool, default: False ) –

    返回详细的文件信息或仅返回文件 ID 和路径。

  • status (Union[str, List[str]], default: all ) –

    包含在结果中的 KB 组状态或状态列表。默认为所有状态。

  • exclude_status (Optional[Union[str, List[str]], default: None ) –

    从结果中排除的 KB 组状态或状态列表。默认为 None

  • upload_status (Union[str, List[str]], default: all ) –

    包含在结果中的文档上传状态或状态列表。默认为所有状态。

  • exclude_upload_status (Optional[Union[str, List[str]], default: None ) –

    从结果中排除的文档上传状态或状态列表。默认为 None

  • need_reparse (Optional[bool], default: False ) –

    过滤需要重新解析的文件或不需要重新解析的文件。默认为 None

Returns:: List: 如果 details=False,返回包含 (doc_id, path) 的元组列表。 如果 details=True,返回包含附加元数据的详细行列表。 说明: - 方法根据提供的过滤条件动态构建 SQL 查询。 - 使用线程安全锁 (self._db_lock) 确保多线程环境下的数据库访问安全。 - 如果 statusupload_status 参数为列表,则会使用 SQL 的 IN 子句进行处理。

Source code in lazyllm/tools/rag/utils.py
    @abstractmethod
    def list_kb_group_files(self, group: str = None, limit: Optional[int] = None, details: bool = False,
                            status: Union[str, List[str]] = Status.all,
                            exclude_status: Optional[Union[str, List[str]]] = None,
                            upload_status: Union[str, List[str]] = Status.all,
                            exclude_upload_status: Optional[Union[str, List[str]]] = None,
                            need_reparse: Optional[bool] = False):
        """列出指定知识库组中的文件。

Args:
    group (str): 用于过滤文件的 KB 组名。默认为 `None`。
    limit (Optional[int]): 返回的最大文件数量。如果为 `None`,则返回所有匹配的文件。
    details (bool): 返回详细的文件信息或仅返回文件 ID 和路径。
    status (Union[str, List[str]]): 包含在结果中的 KB 组状态或状态列表。默认为所有状态。
    exclude_status (Optional[Union[str, List[str]]): 从结果中排除的 KB 组状态或状态列表。默认为 `None`。
    upload_status (Union[str, List[str]]): 包含在结果中的文档上传状态或状态列表。默认为所有状态。
    exclude_upload_status (Optional[Union[str, List[str]]): 从结果中排除的文档上传状态或状态列表。默认为 `None`。
    need_reparse (Optional[bool]): 过滤需要重新解析的文件或不需要重新解析的文件。默认为 `None`。
**Returns:**:
    List: 如果 `details=False`,返回包含 `(doc_id, path)` 的元组列表。
          如果 `details=True`,返回包含附加元数据的详细行列表。
说明:
    - 方法根据提供的过滤条件动态构建 SQL 查询。
    - 使用线程安全锁 (`self._db_lock`) 确保多线程环境下的数据库访问安全。
    - 如果 `status` 或 `upload_status` 参数为列表,则会使用 SQL 的 `IN` 子句进行处理。
"""
        pass

release() abstractmethod

释放当前管理器的资源。

Source code in lazyllm/tools/rag/utils.py
    @abstractmethod
    def release(self):
        """释放当前管理器的资源。

"""
        pass

set_docs_new_meta(doc_meta) abstractmethod

批量更新文档的元数据。

Parameters:

  • doc_meta (Dict[str, dict]) –

    文档ID到新元数据的映射字典。

Source code in lazyllm/tools/rag/utils.py
    @abstractmethod
    def set_docs_new_meta(self, doc_meta: Dict[str, dict]):
        """批量更新文档的元数据。

Args:
    doc_meta (Dict[str, dict]): 文档ID到新元数据的映射字典。

"""
        pass

table_inited() abstractmethod

检查数据库中的 documents 表是否已初始化。此方法在访问数据库时确保线程安全。 判断数据库中是否存在 documents 表。 返回值: bool: 如果 documents 表存在,返回 True;否则返回 False。 说明: - 使用线程安全锁 (self._db_lock) 确保对数据库的安全访问。 - 通过 self._db_path 连接 SQLite 数据库,并使用 check_same_thread 配置选项。 - 执行 SQL 查询:SELECT name FROM sqlite_master WHERE type='table' AND name='documents' 来检查表是否存在。

Source code in lazyllm/tools/rag/utils.py
    @abstractmethod
    def table_inited(self):
        """检查数据库中的 `documents` 表是否已初始化。此方法在访问数据库时确保线程安全。
判断数据库中是否存在 `documents` 表。
返回值:
    bool: 如果 `documents` 表存在,返回 `True`;否则返回 `False`。
说明:
    - 使用线程安全锁 (`self._db_lock`) 确保对数据库的安全访问。
    - 通过 `self._db_path` 连接 SQLite 数据库,并使用 `check_same_thread` 配置选项。
    - 执行 SQL 查询:`SELECT name FROM sqlite_master WHERE type='table' AND name='documents'` 来检查表是否存在。
"""
        pass

update_file_message(fileid, **kw) abstractmethod

更新指定文件的消息。

Parameters:

  • fileid (str) –

    文件ID。

  • **kw

    需要更新的其他键值对。

Source code in lazyllm/tools/rag/utils.py
    @abstractmethod
    def update_file_message(self, fileid: str, **kw):
        """更新指定文件的消息。

Args:
    fileid (str): 文件ID。
    **kw: 需要更新的其他键值对。
"""
        pass

update_file_status(file_ids, status, cond_status_list=None) abstractmethod

更新指定文件的状态。

Parameters:

  • file_ids (list of str) –

    更新状态的文件ID列表。

  • status (str) –

    目标状态。

  • cond_status_list(Union[None, (List[str]]) –

    限制只更新处于这些状态的文档

Source code in lazyllm/tools/rag/utils.py
    @abstractmethod
    def update_file_status(self, file_ids: List[str], status: str,
                           cond_status_list: Union[None, List[str]] = None) -> List[DocPartRow]:
        """更新指定文件的状态。

Args:
    file_ids (list of str): 更新状态的文件ID列表。
    status (str): 目标状态。
    cond_status_list(Union[None, List[str]]):限制只更新处于这些状态的文档
"""
        pass

update_kb_group(cond_file_ids, cond_group=None, cond_status_list=None, new_status=None, new_need_reparse=None) abstractmethod

更新指定知识库分组中的内容。

Parameters:

  • cond_file_ids (list of str) –

    过滤使用的文件ID列表,默认为None。

  • cond_group (str, default: None ) –

    过滤使用的知识库分组名称,默认为None。

  • cond_status_list (list of str, default: None ) –

    过滤使用的状态列表,默认为None。

  • new_status (str, default: None ) –

    新状态, 默认为None。

  • new_need_reparse ((bool, optinoal), default: None ) –

    新的是否需重解析标志。

Returns: - list: 得到更新的列表list of (doc_id, group_name)

Source code in lazyllm/tools/rag/utils.py
    @abstractmethod
    def update_kb_group(self, cond_file_ids: List[str], cond_group: Optional[str] = None,
                        cond_status_list: Optional[List[str]] = None, new_status: Optional[str] = None,
                        new_need_reparse: Optional[bool] = None) -> List[GroupDocPartRow]:
        """更新指定知识库分组中的内容。

Args:
    cond_file_ids (list of str, optional): 过滤使用的文件ID列表,默认为None。
    cond_group (str, optional): 过滤使用的知识库分组名称,默认为None。
    cond_status_list (list of str, optional): 过滤使用的状态列表,默认为None。
    new_status (str, optional): 新状态, 默认为None。
    new_need_reparse (bool, optinoal): 新的是否需重解析标志。

**Returns:**
- list: 得到更新的列表list of (doc_id, group_name)
"""
        pass

update_need_reparsing(doc_id, need_reparse) abstractmethod

更新 KBGroupDocuments 表中某个文档的 need_reparse 状态。 此方法设置指定文档的 need_reparse 标志,并可选限定到特定分组。 参数: doc_id (str): 要更新的文档ID。 need_reparse (bool): need_reparse 标志的新值。 group_name (Optional[str]): 如果提供,仅对指定分组应用更新;如果未提供,则对包含该文档的所有分组应用更新。 说明: - 使用线程安全锁 (self._db_lock) 确保数据库访问安全。 - group_name 参数允许将更新限定到特定分组;如果未提供,则更新应用于包含该文档的所有分组。 - 方法会立刻将更改提交到数据库。

Source code in lazyllm/tools/rag/utils.py
    @abstractmethod
    def update_need_reparsing(self, doc_id: str, need_reparse: bool):
        """更新 `KBGroupDocuments` 表中某个文档的 `need_reparse` 状态。
此方法设置指定文档的 `need_reparse` 标志,并可选限定到特定分组。
参数:
    doc_id (str): 要更新的文档ID。
    need_reparse (bool): `need_reparse` 标志的新值。
    group_name (Optional[str]): 如果提供,仅对指定分组应用更新;如果未提供,则对包含该文档的所有分组应用更新。
说明:
    - 使用线程安全锁 (`self._db_lock`) 确保数据库访问安全。
    - `group_name` 参数允许将更新限定到特定分组;如果未提供,则更新应用于包含该文档的所有分组。
    - 方法会立刻将更改提交到数据库。
"""
        pass

validate_paths(paths) abstractmethod

验证一组文件路径,以确保它们可以被正常处理。 此方法检查提供的路径是否是新的、已处理的或当前正在处理的,并确保处理文档时不会发生冲突。 参数: paths (List[str]): 要验证的文件路径列表。 返回值: Tuple[bool, str, List[bool]]: 返回一个元组,包括: - bool: 如果所有路径有效,则返回 True;否则返回 False。 - str: 表示成功或失败原因的消息。 - List[bool]: 一个布尔值列表,每个元素对应一个路径是否为新路径(True 表示新路径,False 表示已存在)。 说明: - 如果任何文档仍在处理中或需要重新解析,该方法会返回 False,并附带相应的错误消息。 - 方法通过数据库会话和线程安全锁 (self._db_lock) 检索文档状态信息。 - 不安全状态包括 workingwaiting

Source code in lazyllm/tools/rag/utils.py
    @abstractmethod
    def validate_paths(self, paths: List[str]) -> Tuple[bool, str, List[bool]]:
        """验证一组文件路径,以确保它们可以被正常处理。
此方法检查提供的路径是否是新的、已处理的或当前正在处理的,并确保处理文档时不会发生冲突。
参数:
    paths (List[str]): 要验证的文件路径列表。
返回值:
    Tuple[bool, str, List[bool]]: 返回一个元组,包括:
        - `bool`: 如果所有路径有效,则返回 `True`;否则返回 `False`。
        - `str`: 表示成功或失败原因的消息。
        - `List[bool]`: 一个布尔值列表,每个元素对应一个路径是否为新路径(`True` 表示新路径,`False` 表示已存在)。
说明:
    - 如果任何文档仍在处理中或需要重新解析,该方法会返回 `False`,并附带相应的错误消息。
    - 方法通过数据库会话和线程安全锁 (`self._db_lock`) 检索文档状态信息。
    - 不安全状态包括 `working` 和 `waiting`。

"""
        pass

lazyllm.tools.rag.global_metadata.GlobalMetadataDesc

用于描述全局元数据的说明符,包括其类型、可选的元素类型、默认值和大小限制。 class GlobalMetadataDesc 此类用于描述元数据的属性,例如类型、可选约束和默认值。支持标量和数组数据类型,并对某些类型指定特定的大小限制。 Args: data_type (int): 元数据的类型,以整数表示,代表不同的数据类型(例如 VARCHAR、ARRAY 等)。 element_type (Optional[int]): 如果 data_type 是数组,则表示数组中每个元素的类型。默认为 None。 default_value (Optional[Any]): 元数据的默认值。如果未提供,默认值为 None。 max_size (Optional[int]): 元数据的最大大小或长度。如果 data_typeVARCHARARRAY,则此属性为必填项。

Source code in lazyllm/tools/rag/global_metadata.py
class GlobalMetadataDesc:
    """用于描述全局元数据的说明符,包括其类型、可选的元素类型、默认值和大小限制。
`class GlobalMetadataDesc`
此类用于描述元数据的属性,例如类型、可选约束和默认值。支持标量和数组数据类型,并对某些类型指定特定的大小限制。
Args:
    data_type (int): 元数据的类型,以整数表示,代表不同的数据类型(例如 VARCHAR、ARRAY 等)。
    element_type (Optional[int]): 如果 `data_type` 是数组,则表示数组中每个元素的类型。默认为 `None`。
    default_value (Optional[Any]): 元数据的默认值。如果未提供,默认值为 `None`。
    max_size (Optional[int]): 元数据的最大大小或长度。如果 `data_type` 为 `VARCHAR` 或 `ARRAY`,则此属性为必填项。
"""
    # max_size MUST be set when data_type is DataType.VARCHAR or DataType.ARRAY
    def __init__(self, data_type: int, element_type: Optional[int] = None,
                 default_value: Optional[Any] = None, max_size: Optional[int] = None):
        self.data_type = data_type
        self.element_type = element_type
        self.default_value = default_value
        self.max_size = max_size

lazyllm.tools.rag.index_base.IndexBase

Bases: ABC

用于实现索引系统的抽象基类,支持更新、删除和查询文档节点。 class IndexBase(ABC) 此抽象基类定义了索引系统的接口,要求子类实现更新、删除和查询文档节点的方法。

Examples:

>>> from mymodule import IndexBase, DocNode
>>> class MyIndex(IndexBase):
...     def __init__(self):
...         self.nodes = []
...     def update(self, nodes):
...         self.nodes.extend(nodes)
...         print(f"Updated nodes: {nodes}")
...     def remove(self, uids, group_name=None):
...         self.nodes = [node for node in self.nodes if node.uid not in uids]
...         print(f"Removed nodes with uids: {uids}")
...     def query(self, *args, **kwargs):
...         print("Querying nodes...")
...         return self.nodes
>>> index = MyIndex()
>>> doc1 = DocNode(uid="1", content="Document 1")
>>> doc2 = DocNode(uid="2", content="Document 2")
>>> index.update([doc1, doc2])
Updated nodes: [DocNode(uid="1", content="Document 1"), DocNode(uid="2", content="Document 2")]
>>> index.query()
Querying nodes...
[DocNode(uid="1", content="Document 1"), DocNode(uid="2", content="Document 2")]
>>> index.remove(["1"])
Removed nodes with uids: ['1']
>>> index.query()
Querying nodes...
[DocNode(uid="2", content="Document 2")]
Source code in lazyllm/tools/rag/index_base.py
class IndexBase(ABC):
    """用于实现索引系统的抽象基类,支持更新、删除和查询文档节点。
`class IndexBase(ABC)`
此抽象基类定义了索引系统的接口,要求子类实现更新、删除和查询文档节点的方法。


Examples:
    >>> from mymodule import IndexBase, DocNode
    >>> class MyIndex(IndexBase):
    ...     def __init__(self):
    ...         self.nodes = []
    ...     def update(self, nodes):
    ...         self.nodes.extend(nodes)
    ...         print(f"Updated nodes: {nodes}")
    ...     def remove(self, uids, group_name=None):
    ...         self.nodes = [node for node in self.nodes if node.uid not in uids]
    ...         print(f"Removed nodes with uids: {uids}")
    ...     def query(self, *args, **kwargs):
    ...         print("Querying nodes...")
    ...         return self.nodes
    >>> index = MyIndex()
    >>> doc1 = DocNode(uid="1", content="Document 1")
    >>> doc2 = DocNode(uid="2", content="Document 2")
    >>> index.update([doc1, doc2])
    Updated nodes: [DocNode(uid="1", content="Document 1"), DocNode(uid="2", content="Document 2")]
    >>> index.query()
    Querying nodes...
    [DocNode(uid="1", content="Document 1"), DocNode(uid="2", content="Document 2")]
    >>> index.remove(["1"])
    Removed nodes with uids: ['1']
    >>> index.query()
    Querying nodes...
    [DocNode(uid="2", content="Document 2")]
    """
    # TODO(chenjiahao): change params `nodes` to `segments`, index should be able to handle segments
    @abstractmethod
    def update(self, nodes: List[DocNode]) -> None:
        """更新索引内容。

该方法接收一组文档节点对象,并将其添加或更新到索引结构中。通常用于增量构建或刷新索引。

Args:
    nodes (List[DocNode]): 需要更新的文档节点列表。
"""
        pass

    @abstractmethod
    def remove(self, uids: List[str], group_name: Optional[str] = None) -> None:
        """从索引中移除指定文档节点。

可根据唯一标识符列表删除索引中的文档节点,可选地指定组名称以限定范围。

Args:
    uids (List[str]): 需要移除的文档节点的唯一标识符列表。
    group_name (Optional[str]): 可选的组名称,用于限定要删除的范围。
"""
        pass

    @abstractmethod
    def query(self, *args, **kwargs) -> List[DocNode]:
        """执行索引查询。

根据传入的参数执行查询操作,返回匹配的文档节点列表。具体查询逻辑由实现类定义。

Returns:
    List[DocNode]: 查询结果的文档节点列表。
"""
        pass

query(*args, **kwargs) abstractmethod

执行索引查询。

根据传入的参数执行查询操作,返回匹配的文档节点列表。具体查询逻辑由实现类定义。

Returns:

  • List[DocNode]

    List[DocNode]: 查询结果的文档节点列表。

Source code in lazyllm/tools/rag/index_base.py
    @abstractmethod
    def query(self, *args, **kwargs) -> List[DocNode]:
        """执行索引查询。

根据传入的参数执行查询操作,返回匹配的文档节点列表。具体查询逻辑由实现类定义。

Returns:
    List[DocNode]: 查询结果的文档节点列表。
"""
        pass

remove(uids, group_name=None) abstractmethod

从索引中移除指定文档节点。

可根据唯一标识符列表删除索引中的文档节点,可选地指定组名称以限定范围。

Parameters:

  • uids (List[str]) –

    需要移除的文档节点的唯一标识符列表。

  • group_name (Optional[str], default: None ) –

    可选的组名称,用于限定要删除的范围。

Source code in lazyllm/tools/rag/index_base.py
    @abstractmethod
    def remove(self, uids: List[str], group_name: Optional[str] = None) -> None:
        """从索引中移除指定文档节点。

可根据唯一标识符列表删除索引中的文档节点,可选地指定组名称以限定范围。

Args:
    uids (List[str]): 需要移除的文档节点的唯一标识符列表。
    group_name (Optional[str]): 可选的组名称,用于限定要删除的范围。
"""
        pass

update(nodes) abstractmethod

更新索引内容。

该方法接收一组文档节点对象,并将其添加或更新到索引结构中。通常用于增量构建或刷新索引。

Parameters:

  • nodes (List[DocNode]) –

    需要更新的文档节点列表。

Source code in lazyllm/tools/rag/index_base.py
    @abstractmethod
    def update(self, nodes: List[DocNode]) -> None:
        """更新索引内容。

该方法接收一组文档节点对象,并将其添加或更新到索引结构中。通常用于增量构建或刷新索引。

Args:
    nodes (List[DocNode]): 需要更新的文档节点列表。
"""
        pass

lazyllm.tools.rag.IndexBase.update(nodes) abstractmethod

更新索引内容。

该方法接收一组文档节点对象,并将其添加或更新到索引结构中。通常用于增量构建或刷新索引。

Parameters:

  • nodes (List[DocNode]) –

    需要更新的文档节点列表。

Source code in lazyllm/tools/rag/index_base.py
    @abstractmethod
    def update(self, nodes: List[DocNode]) -> None:
        """更新索引内容。

该方法接收一组文档节点对象,并将其添加或更新到索引结构中。通常用于增量构建或刷新索引。

Args:
    nodes (List[DocNode]): 需要更新的文档节点列表。
"""
        pass

lazyllm.tools.rag.IndexBase.remove(uids, group_name=None) abstractmethod

从索引中移除指定文档节点。

可根据唯一标识符列表删除索引中的文档节点,可选地指定组名称以限定范围。

Parameters:

  • uids (List[str]) –

    需要移除的文档节点的唯一标识符列表。

  • group_name (Optional[str], default: None ) –

    可选的组名称,用于限定要删除的范围。

Source code in lazyllm/tools/rag/index_base.py
    @abstractmethod
    def remove(self, uids: List[str], group_name: Optional[str] = None) -> None:
        """从索引中移除指定文档节点。

可根据唯一标识符列表删除索引中的文档节点,可选地指定组名称以限定范围。

Args:
    uids (List[str]): 需要移除的文档节点的唯一标识符列表。
    group_name (Optional[str]): 可选的组名称,用于限定要删除的范围。
"""
        pass

lazyllm.tools.rag.IndexBase.query(*args, **kwargs) abstractmethod

执行索引查询。

根据传入的参数执行查询操作,返回匹配的文档节点列表。具体查询逻辑由实现类定义。

Returns:

  • List[DocNode]

    List[DocNode]: 查询结果的文档节点列表。

Source code in lazyllm/tools/rag/index_base.py
    @abstractmethod
    def query(self, *args, **kwargs) -> List[DocNode]:
        """执行索引查询。

根据传入的参数执行查询操作,返回匹配的文档节点列表。具体查询逻辑由实现类定义。

Returns:
    List[DocNode]: 查询结果的文档节点列表。
"""
        pass

lazyllm.tools.BaseEvaluator

Bases: ModuleBase

评估模块的抽象基类。

该类定义了模型评估的标准接口,支持并发处理、输入校验和评估结果的自动保存,同时内置了重试机制。

Parameters:

  • concurrency (int, default: 1 ) –

    评估过程中使用的并发线程数。

  • retry (int, default: 3 ) –

    每个样本的最大重试次数。

  • log_base_name (Optional[str], default: None ) –

    用于保存结果文件的日志文件名前缀(可选)。

Examples:

>>> from lazyllm.components import BaseEvaluator
>>> class SimpleAccuracyEvaluator(BaseEvaluator):
...     def _process_one_data_impl(self, data):
...         return {
...             "final_score": float(data["pred"] == data["label"])
...         }
>>> evaluator = SimpleAccuracyEvaluator()
>>> score = evaluator([
...     {"pred": "yes", "label": "yes"},
...     {"pred": "no", "label": "yes"}
... ])
>>> print(score)
... 0.5
Source code in lazyllm/tools/eval/eval_base.py
class BaseEvaluator(ModuleBase):
    """评估模块的抽象基类。

该类定义了模型评估的标准接口,支持并发处理、输入校验和评估结果的自动保存,同时内置了重试机制。

Args:
    concurrency (int): 评估过程中使用的并发线程数。
    retry (int): 每个样本的最大重试次数。
    log_base_name (Optional[str]): 用于保存结果文件的日志文件名前缀(可选)。


Examples:
    >>> from lazyllm.components import BaseEvaluator
    >>> class SimpleAccuracyEvaluator(BaseEvaluator):
    ...     def _process_one_data_impl(self, data):
    ...         return {
    ...             "final_score": float(data["pred"] == data["label"])
    ...         }
    >>> evaluator = SimpleAccuracyEvaluator()
    >>> score = evaluator([
    ...     {"pred": "yes", "label": "yes"},
    ...     {"pred": "no", "label": "yes"}
    ... ])
    >>> print(score)
    ... 0.5
    """
    def __init__(self, concurrency=1, retry=3, log_base_name=None):
        super().__init__()
        self._concurrency = concurrency
        self._retry = retry
        self._lock = threading.Lock()
        self._warp = warp(self.process_one_data, _concurrent=self._concurrency)
        self._necessary_keys = []

    def _execute_with_retries(self, input_data, func, result_validator=None, post_processor=None):
        for attempt in range(1, self._retry + 1):
            try:
                result = func(input_data)
                if post_processor is not None:
                    result = post_processor(result)
                if result_validator is None or result_validator(result):
                    return result
                lazyllm.LOG.warning(f"Validation failed on attempt {attempt}/{self._retry}")
            except Exception as e:
                lazyllm.LOG.error(f"Attempt {attempt}/{self._retry} failed: {str(e)}")
        lazyllm.LOG.error(f"All {self._retry} attempts exhausted")
        return ''

    def forward(self, data):
        if not data:
            lazyllm.LOG.warning("Empty input data received")
            return 0.0

        with tqdm(total=len(data), desc=self.__class__.__name__.title()) as progress_bar:
            results = self.batch_process(data, progress_bar)

        if not results:
            return 0.0

        total_score = sum(item.get('final_score', 0) for item in results)
        return total_score / len(results)

    def process_one_data(self, data, progress_bar=None):
        res = self._process_one_data_impl(data)
        if progress_bar is not None:
            with self._lock:
                progress_bar.update(1)
        return res

    @abc.abstractmethod
    def _process_one_data_impl(self, data):
        pass

    def validate_inputs_key(self, data):
        if not isinstance(data, list):
            raise RuntimeError(f"The data should be a list, but got {type(data)}")
        for i, item in enumerate(data):
            if not isinstance(item, dict):
                raise RuntimeError(f"The item at index {i} should be a dict, but got {type(item)}")
            missing_keys = [key for key in self._necessary_keys if key not in item]
            if missing_keys:
                raise RuntimeError(
                    f"The dict at index {i} should contain "
                    f"keys: {self._necessary_keys}, but cannot find: {missing_keys}")

    def batch_process(self, data, progress_bar):
        self.validate_inputs_key(data)
        results = self._warp(data, progress_bar=progress_bar)
        self.save_res(results)
        return results

    def save_res(self, data, eval_res_save_name=None):
        save_dir = lazyllm.config['eval_result_dir']
        os.makedirs(save_dir, exist_ok=True)

        filename = eval_res_save_name or self.__class__.__name__
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        save_path = os.path.join(save_dir, f"{filename}_{timestamp}.json")
        try:
            with open(save_path, 'w') as file:
                json.dump(data, file, ensure_ascii=False, indent=4)
        except Exception as e:
            lazyllm.LOG.error(f"Dump Json error: {e}")

lazyllm.tools.ResponseRelevancy

Bases: BaseEvaluator

用于评估用户问题与模型生成问题之间语义相关性的指标类。

该评估器使用语言模型根据回答生成问题,并通过 Embedding 与余弦相似度度量其与原始问题之间的相关性。

Parameters:

  • llm (ModuleBase) –

    用于根据回答生成问题的语言模型模块。

  • embedding (ModuleBase) –

    用于编码问题向量的嵌入模块。

  • prompt ((str, 可选), default: None ) –

    自定义的生成提示词,若不提供将使用默认提示。

  • prompt_lang (str, default: 'en' ) –

    默认提示词的语言,可选 'en'(默认)或 'zh'

  • num_infer_questions (int, default: 3 ) –

    每条数据生成和评估的问题数量。

  • retry (int, default: 3 ) –

    失败时的重试次数。

  • concurrency (int, default: 1 ) –

    并发评估的数量。

Examples:

>>> from lazyllm.components import ResponseRelevancy
>>> relevancy = ResponseRelevancy(
...     llm=YourLLM(),
...     embedding=YourEmbedding(),
...     prompt_lang="en",
...     num_infer_questions=3
... )
>>> result = relevancy([
...     {"question": "What is the capital of France?", "answer": "Paris is the capital city of France."}
... ])
>>> print(result)
... 0.95  # (a float score between 0 and 1)
Source code in lazyllm/tools/eval/rag_generator_metrics.py
class ResponseRelevancy(BaseEvaluator):
    """用于评估用户问题与模型生成问题之间语义相关性的指标类。

该评估器使用语言模型根据回答生成问题,并通过 Embedding 与余弦相似度度量其与原始问题之间的相关性。


Args:
    llm (ModuleBase): 用于根据回答生成问题的语言模型模块。
    embedding (ModuleBase): 用于编码问题向量的嵌入模块。
    prompt (str, 可选): 自定义的生成提示词,若不提供将使用默认提示。
    prompt_lang (str): 默认提示词的语言,可选 `'en'`(默认)或 `'zh'`。
    num_infer_questions (int): 每条数据生成和评估的问题数量。
    retry (int): 失败时的重试次数。
    concurrency (int): 并发评估的数量。


Examples:
    >>> from lazyllm.components import ResponseRelevancy
    >>> relevancy = ResponseRelevancy(
    ...     llm=YourLLM(),
    ...     embedding=YourEmbedding(),
    ...     prompt_lang="en",
    ...     num_infer_questions=3
    ... )
    >>> result = relevancy([
    ...     {"question": "What is the capital of France?", "answer": "Paris is the capital city of France."}
    ... ])
    >>> print(result)
    ... 0.95  # (a float score between 0 and 1)
    """
    _default_generate_prompt_en = (
        'Please generate the most likely question based on '
        'the input, keeping it concise and to the point.')
    _default_generate_prompt_zh = ('请根据输入生成最可能的一个问题,保持简洁明了。')

    def __init__(self, llm, embedding, prompt=None, prompt_lang='en',
                 num_infer_questions=3, retry=3, concurrency=1):
        super().__init__(concurrency, retry)
        if prompt_lang.strip().lower() == 'zh':
            default_prompt = self._default_generate_prompt_zh
        else:
            default_prompt = self._default_generate_prompt_en
        self._llm = llm.prompt(prompt or default_prompt)
        self._embedding = embedding
        self._num_infer_questions = num_infer_questions
        self._necessary_keys = ['question', 'answer']

    def _cosine(self, x, y):
        product = np.dot(x, y)
        norm = np.linalg.norm(x) * np.linalg.norm(y)
        raw_cosine = product / norm if norm != 0 else 0.0
        return max(0.0, min(raw_cosine, 1.0))

    def _process_one_data_impl(self, data):
        one_total_score = 0
        res = copy.deepcopy(data)
        res['infer_questions'] = []
        for _ in range(self._num_infer_questions):
            # Generate Questions:
            guess_question = self._execute_with_retries(data['answer'], self._llm)

            # Calculate Similarity:
            try:
                if isinstance(self._embedding, lazyllm.module.OnlineEmbeddingModuleBase):
                    vector1 = self._embedding(guess_question)
                    vector2 = self._embedding(data['question'])
                else:
                    vector1, vector2 = json.loads(self._embedding([guess_question, data['question']]))
                score = self._cosine(vector1, vector2)
            except Exception as e:
                lazyllm.LOG.error(f'Eval-Infer Error: {e}')
                score = 0
            res['infer_questions'].append({
                'question': guess_question,
                'score': round(score, 4)
            })
            one_total_score += score
        res['final_score'] = round(one_total_score / self._num_infer_questions, 4)
        return res

lazyllm.tools.Faithfulness

Bases: BaseEvaluator

评估回答与上下文之间事实一致性的指标类。

该评估器首先使用语言模型将答案拆分为独立事实句,然后基于上下文对每条句子进行支持性判断(0或1分),最终取平均值作为总体一致性分数。

Parameters:

  • llm (ModuleBase) –

    同时用于生成句子与进行评估的语言模型模块。

  • generate_prompt ((str, 可选), default: None ) –

    用于将答案转换为事实句的自定义提示词。

  • eval_prompt ((str, 可选), default: None ) –

    用于评估句子与上下文匹配度的提示词。

  • prompt_lang (str, default: 'en' ) –

    默认提示词的语言,可选 'en' 或 'zh'。

  • retry (int, default: 3 ) –

    生成或评估失败时的最大重试次数。

  • concurrency (int, default: 1 ) –

    并发评估的数据条数。

Examples:

>>> from lazyllm.components import Faithfulness
>>> evaluator = Faithfulness(llm=YourLLM(), prompt_lang="en")
>>> data = {
...     "question": "What is the role of ATP in cells?",
...     "answer": "ATP stores energy and transfers it within cells.",
...     "context": "ATP is the energy currency of the cell. It provides energy for many biochemical reactions."
... }
>>> result = evaluator([data])
>>> print(result)
... 1.0  # Average binary score of all factual statements
Source code in lazyllm/tools/eval/rag_generator_metrics.py
class Faithfulness(BaseEvaluator):
    """评估回答与上下文之间事实一致性的指标类。

该评估器首先使用语言模型将答案拆分为独立事实句,然后基于上下文对每条句子进行支持性判断(0或1分),最终取平均值作为总体一致性分数。


Args:
    llm (ModuleBase): 同时用于生成句子与进行评估的语言模型模块。
    generate_prompt (str, 可选): 用于将答案转换为事实句的自定义提示词。
    eval_prompt (str, 可选): 用于评估句子与上下文匹配度的提示词。
    prompt_lang (str): 默认提示词的语言,可选 'en' 或 'zh'。
    retry (int): 生成或评估失败时的最大重试次数。
    concurrency (int): 并发评估的数据条数。


Examples:
    >>> from lazyllm.components import Faithfulness
    >>> evaluator = Faithfulness(llm=YourLLM(), prompt_lang="en")
    >>> data = {
    ...     "question": "What is the role of ATP in cells?",
    ...     "answer": "ATP stores energy and transfers it within cells.",
    ...     "context": "ATP is the energy currency of the cell. It provides energy for many biochemical reactions."
    ... }
    >>> result = evaluator([data])
    >>> print(result)
    ... 1.0  # Average binary score of all factual statements
    """
    _default_generate_prompt_en = (
        '[Task Description]\n'
        'Split the answer into independent factual statements using "|||" as '
        'the exclusive separator, following these rules:\n'
        '1. Each statement must be a complete sentence ending with proper punctuation\n'
        '2. Never use line breaks or other symbols as separators\n'
        '3. Statements containing "|||" must be rephrased\n'
        '4. Each statement must be clear, pronoun-free.\n'
        '[Output Format]\n'
        'statement_1|||statement_2|||statement_3\n'
        '[Example Input]\n'
        'Q: How does photosynthesis work?\n'
        'A: The process requires sunlight, then chlorophyll absorbs light energy. '
        'It converts water and CO2 into glucose.\n'
        '[Example Output]\n'
        'Photosynthesis requires sunlight.|||Chlorophyll absorbs light energy.'
        '|||Chlorophyll converts water and CO2 into glucose.\n'
    )
    _default_eval_prompt_en = (
        '[Task Description]\n'
        'Evaluate each "|||"-separated statement against provided context using binary scoring:\n'
        'Fully supported by context: 1\n'
        'Unsupported/contradictory: 0\n'
        '[Output Requirements]\n'
        '1. JSON format with array of objects\n'
        '2. Each object contains:\n'
        '    - "statement": Original text\n'
        '    - "score": 1 or 0\n'
        '3. Wrap output in ```json code block\n'
        '[Example Input]\n'
        'Context: Photosynthesis occurs in chloroplasts. Light reactions produce ATP using sunlight. '
        'Calvin cycle fixes CO2 into sugars.\n'
        'Statements: Photosynthesis requires sunlight.|||Chlorophyll absorbs light energy.'
        '|||Chlorophyll converts water and CO2 into glucose.\n'
        '[Example Output]\n'
        '[{"statement": "Photosynthesis requires sunlight.","score": 1},'
        '{"statement": "Chlorophyll absorbs light energy.", "score": 1},'
        '{"statement": "Chlorophyll converts water and CO2 into glucose.","score": 0}]\n'
    )
    _default_generate_prompt_zh = (
        '[任务描述]\n'
        '使用"|||"作为唯一分隔符,将答案分割成独立的事实陈述,遵循以下规则:\n'
        '1. 每个陈述必须是完整的句子,并以适当的标点结束\n'
        '2. 不要使用换行符或其他符号作为分隔符\n'
        '3. 包含"|||"的陈述必须重新措辞\n'
        '4. 每个陈述必须清晰,不包含代词。\n'
        '[输出格式]\n'
        'statement_1|||statement_2|||statement_3\n'
        '[示例输入]\n'
        'Q: 光合作用是如何工作的?\n'
        'A: 该过程需要阳光,然后叶绿素吸收光能。它将水和CO2转化为葡萄糖。\n'
        '[示例输出]\n'
        '光合作用需要阳光。|||叶绿素吸收光能。|||叶绿素将水和CO2转化为葡萄糖。\n'
    )
    _default_eval_prompt_zh = (
        '[任务描述]\n'
        '使用二进制评分对每个"|||"分隔的陈述与提供的内容进行评估:\n'
        '完全由内容支持:1\n'
        '不支持/矛盾:0\n'
        '[输出要求]\n'
        '1. JSON格式,包含对象数组\n'
        '2. 每个对象包含:\n'
        '    - "statement": 原始文本\n'
        '    - "score": 1或0\n'
        '3. 将输出包裹在```json代码块中\n'
        '[示例输入]\n'
        'Context: 光合作用发生在叶绿体中。光反应利用阳光产生ATP。卡尔文循环将CO2固定成糖。\n'
        'Statements: 光合作用需要阳光。|||叶绿素吸收光能。|||叶绿素将水和CO2转化为葡萄糖。\n'
        '[示例输出]\n'
        '[{"statement": "光合作用需要阳光。","score": 1},'
        '{"statement": "叶绿素吸收光能。", "score": 1},'
        '{"statement": "叶绿素将水和CO2转化为葡萄糖。","score": 0}]\n'
    )

    def __init__(self, llm, generate_prompt=None, eval_prompt=None, prompt_lang='en', retry=3, concurrency=1):
        super().__init__(concurrency, retry)
        self._base_llm = llm
        if prompt_lang == 'zh':
            default_generate_prompt = generate_prompt or self._default_generate_prompt_zh
            default_eval_prompt = eval_prompt or self._default_eval_prompt_zh
        else:
            default_generate_prompt = generate_prompt or self._default_generate_prompt_en
            default_eval_prompt = eval_prompt or self._default_eval_prompt_en
        self._build_llms(self._base_llm, default_generate_prompt, default_eval_prompt)
        self._necessary_keys = ['question', 'answer', 'context']

    def _build_llms(self, base_llm, generate_prompt, eval_prompt):
        self._gene_llm = base_llm.share(prompt=generate_prompt)
        self._eval_llm = base_llm.share(prompt=eval_prompt).formatter(JsonFormatter())

    def _validate_eval_result(self, result):
        return (
            isinstance(result, list)
            and len(result) > 0
            and all(isinstance(i, dict) and 'score' in i for i in result)
        )

    def _post_processor(self, eval_result):
        if isinstance(eval_result, dict):
            eval_result = [eval_result]
        return eval_result

    def _process_one_data_impl(self, data):
        res = copy.deepcopy(data)
        # Generate Statements:
        query1 = f'Q: {data["question"]}\nA: {data["answer"]}'
        statements = self._execute_with_retries(query1, self._gene_llm)
        res['statements'] = statements

        # Eval Statements in Context:
        query2 = f'Context: {data["context"]}\nStatements: {statements}'
        eval_result = self._execute_with_retries(
            query2, self._eval_llm, self._validate_eval_result, self._post_processor)
        if not self._validate_eval_result(eval_result):
            lazyllm.LOG.error("Invalid evaluation result format")
            res.update({'scores': [], 'final_score': 0.0})
            return res

        total_score = sum(
            int(entry.get('score', 0)) if entry.get('score') in (0, 1) else 0
            for entry in eval_result
        )
        res['scores'] = eval_result
        res['final_score'] = round(total_score / len(eval_result), 4) if eval_result else 0.0
        return res

lazyllm.tools.LLMContextRecall

Bases: BaseEvaluator

Source code in lazyllm/tools/eval/rag_retriever_metrics.py
class LLMContextRecall(BaseEvaluator):
    _default_eval_prompt_en = (
        '[Task Description]\n'
        'Given a context, and an answer, analyze each sentence in the answer and '
        'classify if the sentence can be attributed to the given context or not:\n'
        'Fully supported by context: 1\n'
        'Unsupported/contradictory: 0\n'
        '[Output Requirements]\n'
        '1. JSON format with array of objects\n'
        '2. Each object contains:\n'
        '    - "statement": Original text\n'
        '    - "reason": the reason why it is scored 1/0\n'
        '    - "score": 1 or 0\n'
        '3. Wrap output in ```json code block\n'
        '[Example Input]\n'
        'Question: What is Photosynthesis?'
        'Context: Photosynthesis occurs in chloroplasts. Light reactions produce ATP using sunlight.\n'
        'Statements: Photosynthesis was discovered in 1780s. It occurs in chloroplasts and produce ATP using sunlight.\n'
        '[Example Output]\n'
        '[{"statement": "Photosynthesis was discovered in 1780s", '
        '"reason": "The time when photosynthesis discovered was not mentioned in the given context","score": 0},'
        ' {"statement": "It occurs in chloroplasts and produce ATP using sunlight.", '
        '"reason": "The exact sentence is present in the given context", "score": 1}]\n'
    )
    _default_eval_prompt_zh = (
        '[任务描述]\n'
        '给定一个上下文和一个答案,分析答案中的每个句子并判断该句子是否可以归因于给定的上下文:\n'
        '完全受上下文支持:1\n'
        '不支持/矛盾:0\n'
        '[输出要求]\n'
        '1. 带有对象数组的 JSON 格式\n'
        '2. 每个对象包含:\n'
        ' - "statement":原始文本\n'
        ' - "reason":评分原因\n'
        ' - "score":1 或 0\n'
        '3. 将输出包裹在 ```json 代码块中\n'
        '[示例输入]\n'
        'question:什么是光合作用?'
        'context:光合作用发生在叶绿体中,利用阳光产生 ATP。\n'
        'statement:光合作用于 1780 年代被发现。光合作用发生在叶绿体中,并利用阳光产生 ATP。\n'
        '[示例输出]\n'
        '[{"statement": "光合作用于 1780 年代被发现", "reason": "给定上下文中未提及发现光合作用被发现的时间","score": 0},'
        ' {"statement": "光合作用发生在叶绿体中,并利用阳光产生 ATP。", "reason": "给定上下文中存在确切的句子", "score": 1}]\n'
    )

    def __init__(self, llm, eval_prompt=None, prompt_lang='en', retry=3, concurrency=1):
        super().__init__(concurrency, retry)
        if prompt_lang == 'zh':
            default_eval_prompt = eval_prompt or self._default_eval_prompt_zh
        else:
            default_eval_prompt = eval_prompt or self._default_eval_prompt_en
        self._llm = llm.prompt(default_eval_prompt).formatter(JsonFormatter()) if llm else None
        self._necessary_keys = ['question', 'answer', 'context_retrieved']

    def _validate_eval_result(self, result):
        return (
            isinstance(result, list)
            and len(result) > 0
            and all(isinstance(i, dict) and 'score' in i for i in result)
        )

    def _post_processor(self, eval_result):
        if isinstance(eval_result, dict):
            eval_result = [eval_result]
        return eval_result

    def _process_one_data_impl(self, data):
        res = copy.deepcopy(data)
        context = "\n".join(data['context_retrieved'])

        query = f'question: {data["question"]}\ncontext: {context}\nstatement: {data["answer"]}'
        eval_result = self._execute_with_retries(
            query, self._llm, self._validate_eval_result, self._post_processor)
        scores = [result["score"] for result in eval_result]

        res['final_score'] = round(sum(scores) / len(scores), 4) if scores else 0.0
        return res

lazyllm.tools.NonLLMContextRecall

Bases: BaseEvaluator

Source code in lazyllm/tools/eval/rag_retriever_metrics.py
class NonLLMContextRecall(BaseEvaluator):
    def __init__(self, th=0.5, binary=True, retry=3, concurrency=1):
        super().__init__(concurrency, retry)
        self._binary = binary
        self._threshold = th
        self._necessary_keys = ['context_retrieved', 'context_reference']

    def _calc_levenshtein_distance(self, reference, context):
        return 1 - rapidfuzz.distance.Levenshtein.normalized_distance(reference, context)

    def _calc_context_recall(self, data):
        contexts, reference = data["context"], data["reference"]
        scores = []
        for context in contexts:
            score = self._calc_levenshtein_distance(reference, context)
            scores.append(score)
        return scores

    def _compute_scores(self, scores):
        binary_scores = [1 if score > self._threshold else 0 for score in scores]

        if self._binary:
            return 1.0 if sum(binary_scores) > 0 else 0.0
        if len(binary_scores) > 0:
            return sum(binary_scores) / len(binary_scores)
        return 0

    def _process_one_data_impl(self, data):
        res = copy.deepcopy(data)
        scores = []
        for reference in data['context_reference']:
            input_data = {'context': data['context_retrieved'], 'reference': reference}
            eval_result = self._execute_with_retries(input_data, self._calc_context_recall)
            scores.append(self._compute_scores(eval_result))

        res['final_score'] = round(sum(scores) / len(scores), 4) if scores else 0.0
        return res

lazyllm.tools.ContextRelevance

Bases: BaseEvaluator

Source code in lazyllm/tools/eval/rag_retriever_metrics.py
class ContextRelevance(BaseEvaluator):
    def __init__(self, splitter="。", retry=3, concurrency=1):
        super().__init__(concurrency, retry)
        self._splitter = splitter
        self._necessary_keys = ['context_retrieved', 'context_reference']

    def _calc_context_relevance(self, data):
        sentences_retrieved, sentences_reference = data["context"], data["reference"]
        scores = [0] * len(sentences_retrieved)
        for i, sentence in enumerate(sentences_retrieved):
            if sentence in sentences_reference:
                scores[i] = 1
        return scores

    def _paragraphs_to_sentences(self, paragraphs):
        sentences = []
        pattern = rf'{re.escape(self._splitter)}+'
        for paragraph in paragraphs:
            sentences.extend([s.strip() for s in re.split(pattern, paragraph) if s.strip()])
        return sentences

    def _process_one_data_impl(self, data):
        res = copy.deepcopy(data)
        retrieved = self._paragraphs_to_sentences(data["context_retrieved"])
        reference = self._paragraphs_to_sentences(data["context_reference"])

        input_data = {'context': retrieved, 'reference': reference}
        eval_result = self._execute_with_retries(input_data, self._calc_context_relevance)
        total_score = sum(eval_result)

        res['final_score'] = round(total_score / len(eval_result), 4) if eval_result else 0.0
        return res

lazyllm.tools.HttpRequest

Bases: ModuleBase

通用 HTTP 请求执行器。

该类用于构建并发送 HTTP 请求,支持变量替换、API Key 注入、JSON 或表单编码、文件类型响应识别等功能。

Parameters:

  • method (str) –

    HTTP 方法,如 'GET'、'POST' 等。

  • url (str) –

    请求目标的 URL。

  • api_key (str) –

    可选的 API Key,会被加入请求参数。

  • headers (dict) –

    HTTP 请求头。

  • params (dict) –

    URL 查询参数。

  • body (Union[str, dict]) –

    请求体,支持字符串或 JSON 字典格式。

  • timeout (int, default: 10 ) –

    请求超时时间(秒)。

  • proxies (dict, default: None ) –

    可选的代理设置。

Examples:

>>> from lazyllm.components import HttpRequest
>>> request = HttpRequest(
...     method="GET",
...     url="https://api.github.com/repos/openai/openai-python",
...     api_key="",
...     headers={"Accept": "application/json"},
...     params={},
...     body=None
... )
>>> result = request()
>>> print(result["status_code"])
... 200
>>> print(result["content"][:100])
... '{"id":123456,"name":"openai-python", ...}'
Source code in lazyllm/tools/http_request/http_request.py
class HttpRequest(ModuleBase):
    """通用 HTTP 请求执行器。

该类用于构建并发送 HTTP 请求,支持变量替换、API Key 注入、JSON 或表单编码、文件类型响应识别等功能。

Args:
    method (str): HTTP 方法,如 'GET'、'POST' 等。
    url (str): 请求目标的 URL。
    api_key (str): 可选的 API Key,会被加入请求参数。
    headers (dict): HTTP 请求头。
    params (dict): URL 查询参数。
    body (Union[str, dict]): 请求体,支持字符串或 JSON 字典格式。
    timeout (int): 请求超时时间(秒)。
    proxies (dict, optional): 可选的代理设置。


Examples:
    >>> from lazyllm.components import HttpRequest
    >>> request = HttpRequest(
    ...     method="GET",
    ...     url="https://api.github.com/repos/openai/openai-python",
    ...     api_key="",
    ...     headers={"Accept": "application/json"},
    ...     params={},
    ...     body=None
    ... )
    >>> result = request()
    >>> print(result["status_code"])
    ... 200
    >>> print(result["content"][:100])
    ... '{"id":123456,"name":"openai-python", ...}'
    """
    def __init__(self, method, url, api_key, headers, params, body, timeout=10, proxies=None):
        super().__init__()
        if not url:
            return

        self._method = method
        self._url = url
        self._api_key = api_key
        self._headers = headers
        self._params = params
        self._body = body
        self._timeout = timeout
        self._proxies = proxies

    def _process_api_key(self, headers, params):
        if self._api_key and self._api_key != '':
            params = params or {}
            params['api_key'] = self._api_key
        return headers, params

    def forward(self, *args, **kwargs):
        def _map_input(target_str):
            if not isinstance(target_str, str):
                return target_str

            # TODO: replacements could be more complex to create.
            replacements = {**kwargs, **(args[0] if args and isinstance(args[0], dict) else {})}
            if not replacements:
                return target_str

            pattern = r"\{\{([^}]+)\}\}"
            matches = re.findall(pattern, target_str)
            for match in matches:
                replacement = replacements.get(match)
                if replacement is not None:
                    if "{{" + match + "}}" == target_str:
                        return replacement
                    target_str = re.sub(r"\{\{" + re.escape(match) + r"\}\}", replacement, target_str)

            return target_str

        url = _map_input(self._url)
        params = {key: _map_input(value) for key, value in self._params.items()} if self._params else None
        headers = {key: _map_input(value) for key, value in self._headers.items()} if self._headers else None
        headers, params = self._process_api_key(headers, params)
        if isinstance(headers, dict) and headers.get("Content-Type") == "application/json":
            try:
                body = json.loads(self._body) if isinstance(self._body, str) else self._body
                body = {k: _map_input(v) for k, v in body.items()}

                http_response = httpx.request(method=self._method, url=url, headers=headers,
                                              params=params, json=body, timeout=self._timeout,
                                              proxies=self._proxies)
            except json.JSONDecodeError:
                raise ValueError(f"Invalid JSON format: {self._body}")
        else:
            body = (json.dumps({k: _map_input(v) for k, v in self._body.items()})
                    if isinstance(self._body, dict) else _map_input(self._body))

            http_response = httpx.request(method=self._method, url=url, headers=headers,
                                          params=params, data=body, timeout=self._timeout,
                                          proxies=self._proxies)

        response = HttpExecutorResponse(http_response)

        _, file_binary = response.extract_file()

        outputs = {
            'status_code': response.status_code,
            'content': response.content if len(file_binary) == 0 else None,
            'headers': response.headers,
            'file': file_binary
        }
        return outputs

lazyllm.tools.JobDescription

Bases: BaseModel

模型部署任务描述的数据结构。

用于创建模型推理任务时指定部署配置,包括模型名称与所需 GPU 数量。

Parameters:

  • deploy_model (str) –

    要部署的模型名称,默认为 "qwen1.5-0.5b-chat"。

  • num_gpus (int) –

    所需的 GPU 数量,默认为 1。

Examples:

>>> from lazyllm.components import JobDescription
>>> job = JobDescription(deploy_model="deepseek-coder", num_gpus=2)
>>> print(job.dict())
... {'deploy_model': 'deepseek-coder', 'num_gpus': 2}
Source code in lazyllm/tools/infer_service/serve.py
class JobDescription(BaseModel):
    """模型部署任务描述的数据结构。

用于创建模型推理任务时指定部署配置,包括模型名称与所需 GPU 数量。

Args:
    deploy_model (str): 要部署的模型名称,默认为 "qwen1.5-0.5b-chat"。
    num_gpus (int): 所需的 GPU 数量,默认为 1。


Examples:
    >>> from lazyllm.components import JobDescription
    >>> job = JobDescription(deploy_model="deepseek-coder", num_gpus=2)
    >>> print(job.dict())
    ... {'deploy_model': 'deepseek-coder', 'num_gpus': 2}
    """
    deploy_model: str = Field(default='qwen1.5-0.5b-chat')
    num_gpus: int = Field(default=1)

lazyllm.tools.DBManager

Bases: ABC, ModuleBase

数据库管理器的抽象基类。

该类定义了构建数据库连接器的通用接口,包括 execute_query 抽象方法和 desc 描述属性。

Parameters:

  • db_type (str) –

    数据库类型标识符,例如 'mysql'、'mongodb'。

Examples:

>>> from lazyllm.components import DBManager
>>> class DummyDB(DBManager):
...     def __init__(self):
...         super().__init__(db_type="dummy")
...     def execute_query(self, statement):
...         return f"Executed: {statement}"
...     @property
...     def desc(self):
...         return "Dummy database for testing."
>>> db = DummyDB()
>>> print(db("SELECT * FROM test"))
... Executed: SELECT * FROM test
Source code in lazyllm/tools/sql/db_manager.py
class DBManager(ABC, ModuleBase, metaclass=CommonMeta):
    """数据库管理器的抽象基类。

该类定义了构建数据库连接器的通用接口,包括 `execute_query` 抽象方法和 `desc` 描述属性。

Args:
    db_type (str): 数据库类型标识符,例如 'mysql'、'mongodb'。


Examples:
    >>> from lazyllm.components import DBManager
    >>> class DummyDB(DBManager):
    ...     def __init__(self):
    ...         super().__init__(db_type="dummy")
    ...     def execute_query(self, statement):
    ...         return f"Executed: {statement}"
    ...     @property
    ...     def desc(self):
    ...         return "Dummy database for testing."
    >>> db = DummyDB()
    >>> print(db("SELECT * FROM test"))
    ... Executed: SELECT * FROM test
    """

    def __init__(self, db_type: str):
        ModuleBase.__init__(self)
        self._db_type = db_type
        self._desc = None

    @abstractmethod
    def execute_query(self, statement) -> str:
        """执行数据库查询语句的抽象方法。此方法需要由具体的数据库管理器子类实现,用于执行各种数据库操作。

Args:
    statement: 要执行的数据库查询语句,可以是 SQL 语句或其他数据库特定的查询语言

此方法的特点:

- **抽象方法**: 需要在子类中实现具体的数据库操作逻辑
- **统一接口**: 为不同的数据库类型提供统一的查询接口
- **错误处理**: 子类实现应该包含适当的错误处理和状态报告
- **结果格式化**: 返回格式化的字符串结果,便于后续处理

**注意**: 此方法是数据库管理器的核心方法,所有具体的数据库操作都通过此方法执行。

"""
        pass

    def forward(self, statement: str) -> str:
        return self.execute_query(statement)

    @property
    def db_type(self) -> str:
        return self._db_type

    @property
    @abstractmethod
    def desc(self) -> str: pass

    @staticmethod
    def _is_dict_all_str(d):
        if not isinstance(d, dict):
            return False
        return all(isinstance(key, str) and (isinstance(value, str) or DBManager._is_dict_all_str(value))
                   for key, value in d.items())

    @staticmethod
    def _serialize_uncommon_type(obj):
        if not isinstance(obj, (int, str, float, bool, tuple, list, dict)):
            return str(obj)

execute_query(statement) abstractmethod

执行数据库查询语句的抽象方法。此方法需要由具体的数据库管理器子类实现,用于执行各种数据库操作。

Parameters:

  • statement

    要执行的数据库查询语句,可以是 SQL 语句或其他数据库特定的查询语言

此方法的特点:

  • 抽象方法: 需要在子类中实现具体的数据库操作逻辑
  • 统一接口: 为不同的数据库类型提供统一的查询接口
  • 错误处理: 子类实现应该包含适当的错误处理和状态报告
  • 结果格式化: 返回格式化的字符串结果,便于后续处理

注意: 此方法是数据库管理器的核心方法,所有具体的数据库操作都通过此方法执行。

Source code in lazyllm/tools/sql/db_manager.py
    @abstractmethod
    def execute_query(self, statement) -> str:
        """执行数据库查询语句的抽象方法。此方法需要由具体的数据库管理器子类实现,用于执行各种数据库操作。

Args:
    statement: 要执行的数据库查询语句,可以是 SQL 语句或其他数据库特定的查询语言

此方法的特点:

- **抽象方法**: 需要在子类中实现具体的数据库操作逻辑
- **统一接口**: 为不同的数据库类型提供统一的查询接口
- **错误处理**: 子类实现应该包含适当的错误处理和状态报告
- **结果格式化**: 返回格式化的字符串结果,便于后续处理

**注意**: 此方法是数据库管理器的核心方法,所有具体的数据库操作都通过此方法执行。

"""
        pass

lazyllm.tools.MongoDBManager

Bases: DBManager

Source code in lazyllm/tools/sql/mongodb_manager.py
class MongoDBManager(DBManager):
    MAX_TIMEOUT_MS = 5000

    def __init__(self, user: str, password: str, host: str, port: int, db_name: str, collection_name: str, **kwargs):
        super().__init__(db_type="mongodb")
        self._user = user
        self._password = password
        self._host = host
        self._port = port
        self._db_name = db_name
        self._collection_name = collection_name
        self._collection = None
        self._options_str = kwargs.get("options_str")
        self._conn_url = self._gen_conn_url()
        self._collection_desc_dict = kwargs.get("collection_desc_dict")

    @property
    def db_name(self):
        return self._db_name

    @property
    def collection_name(self):
        return self._collection_name

    def _gen_conn_url(self) -> str:
        password = quote_plus(self._password)
        conn_url = (f"{self._db_type}://{self._user}:{password}@{self._host}:{self._port}/"
                    f"{('?' + self._options_str) if self._options_str else ''}")
        return conn_url

    @contextmanager
    def get_client(self):
        client = pymongo.MongoClient(self._conn_url, serverSelectionTimeoutMS=self.MAX_TIMEOUT_MS)
        try:
            yield client
        finally:
            client.close()

    @property
    def desc(self):
        if self._desc is None:
            self.set_desc(schema_desc_dict=self._collection_desc_dict)
        return self._desc

    def set_desc(self, schema_desc_dict: dict):
        self._collection_desc_dict = schema_desc_dict
        if schema_desc_dict is None:
            with self.get_client() as client:
                egs_one = client[self._db_name][self._collection_name].find_one()
                if egs_one is not None:
                    self._desc = "Collection Example:\n"
                    self._desc += json.dumps(egs_one, ensure_ascii=False, indent=4)
        else:
            self._desc = ""
            try:
                collection_desc = CollectionDesc.model_validate(schema_desc_dict)
            except pydantic.ValidationError as e:
                raise ValueError(f"Validate input schema_desc_dict failed: {str(e)}")
            if not self._is_dict_all_str(collection_desc.schema_type):
                raise ValueError("schema_type shouble be str or nested str dict")
            if not self._is_dict_all_str(collection_desc.schema_desc):
                raise ValueError("schema_desc shouble be str or nested str dict")
            if collection_desc.summary:
                self._desc += f"Collection summary: {collection_desc.summary}\n"
            self._desc += "Collection schema:\n"
            self._desc += json.dumps(collection_desc.schema_type, ensure_ascii=False, indent=4)
            self._desc += "Collection schema description:\n"
            self._desc += json.dumps(collection_desc.schema_type, ensure_ascii=False, indent=4)

    def check_connection(self) -> DBResult:
        try:
            with pymongo.MongoClient(self._conn_url, serverSelectionTimeoutMS=self.MAX_TIMEOUT_MS) as client:
                _ = client.server_info()
            return DBResult()
        except Exception as e:
            return DBResult(status=DBStatus.FAIL, detail=str(e))

    def execute_query(self, statement) -> str:
        str_result = ""
        try:
            pipeline_list = json.loads(statement)
            with self.get_client() as client:
                collection = client[self._db_name][self._collection_name]
                result = list(collection.aggregate(pipeline_list))
                str_result = json.dumps(result, ensure_ascii=False, default=self._serialize_uncommon_type)
        except Exception as e:
            str_result = f"MongoDB ERROR: {str(e)}"
        return str_result

lazyllm.tools.HttpTool

Bases: HttpRequest

用于访问第三方服务和执行自定义代码的模块。参数中的 paramsheaders 的 value,以及 body 中可以包含形如 {{variable}} 这样用两个花括号标记的模板变量,然后在调用的时候通过参数来替换模板中的值。参考 [[lazyllm.tools.HttpTool.forward]] 中的使用说明。

Parameters:

  • method (str, default: None ) –

    指定 http 请求方法,参考 https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods

  • url (str, default: None ) –

    要访问的 url。如果该字段为空,则表示该模块不需要访问第三方服务。

  • params (Dict[str, str], default: None ) –

    请求 url 需要填充的 params 字段。如果 url 为空,该字段会被忽略。

  • headers (Dict[str, str], default: None ) –

    访问 url 需要填充的 header 字段。如果 url 为空,该字段会被忽略。

  • body (Dict[str, str], default: None ) –

    请求 url 需要填充的 body 字段。如果 url 为空,该字段会被忽略。

  • timeout (int, default: 10 ) –

    请求超时时间,单位是秒,默认值是 10。

  • proxies (Dict[str, str], default: None ) –

    指定请求 url 时所使用的代理。代理格式参考 https://www.python-httpx.org/advanced/proxies

  • code_str (str, default: None ) –

    一个字符串,包含用户定义的函数。如果参数 url 为空,则直接执行该函数,执行时所有的参数都会转发给该函数;如果 url 不为空,该函数的参数为请求 url 返回的结果,此时该函数作为 url 返回后的后处理函数。

  • vars_for_code (Dict[str, Any], default: None ) –

    一个字典,传入运行 code 所需的依赖及变量。

  • outputs (Optional[List[str]], default: None ) –

    期望提取的输出字段名。

  • extract_from_result (Optional[bool], default: None ) –

    是否从响应字典中直接提取指定字段。

Examples:

from lazyllm.tools import HttpTool

code_str = "def identity(content): return content"
tool = HttpTool(method='GET', url='http://www.sensetime.com/', code_str=code_str)
ret = tool()
Source code in lazyllm/tools/tools/http_tool.py
class HttpTool(HttpRequest):
    """
用于访问第三方服务和执行自定义代码的模块。参数中的 `params` 和 `headers` 的 value,以及 `body` 中可以包含形如 `{{variable}}` 这样用两个花括号标记的模板变量,然后在调用的时候通过参数来替换模板中的值。参考 [[lazyllm.tools.HttpTool.forward]] 中的使用说明。

Args:
    method (str, optional): 指定 http 请求方法,参考 `https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods`。
    url (str, optional): 要访问的 url。如果该字段为空,则表示该模块不需要访问第三方服务。
    params (Dict[str, str], optional): 请求 url 需要填充的 params 字段。如果 url 为空,该字段会被忽略。
    headers (Dict[str, str], optional): 访问 url 需要填充的 header 字段。如果 url 为空,该字段会被忽略。
    body (Dict[str, str], optional): 请求 url 需要填充的 body 字段。如果 url 为空,该字段会被忽略。
    timeout (int): 请求超时时间,单位是秒,默认值是 10。
    proxies (Dict[str, str], optional): 指定请求 url 时所使用的代理。代理格式参考 `https://www.python-httpx.org/advanced/proxies`。
    code_str (str, optional): 一个字符串,包含用户定义的函数。如果参数 `url` 为空,则直接执行该函数,执行时所有的参数都会转发给该函数;如果 `url` 不为空,该函数的参数为请求 url 返回的结果,此时该函数作为 url 返回后的后处理函数。
    vars_for_code (Dict[str, Any]): 一个字典,传入运行 code 所需的依赖及变量。
    outputs (Optional[List[str]]): 期望提取的输出字段名。
    extract_from_result (Optional[bool]): 是否从响应字典中直接提取指定字段。


Examples:

    from lazyllm.tools import HttpTool

    code_str = "def identity(content): return content"
    tool = HttpTool(method='GET', url='http://www.sensetime.com/', code_str=code_str)
    ret = tool()
    """
    def __init__(self,
                 method: Optional[str] = None,
                 url: Optional[str] = None,
                 params: Optional[Dict[str, str]] = None,
                 headers: Optional[Dict[str, str]] = None,
                 body: Optional[str] = None,
                 timeout: int = 10,
                 proxies: Optional[Dict[str, str]] = None,
                 code_str: Optional[str] = None,
                 vars_for_code: Optional[Dict[str, Any]] = None,
                 outputs: Optional[List[str]] = None,
                 extract_from_result: Optional[bool] = None):
        super().__init__(method, url, '', headers, params, body, timeout, proxies)
        self._has_http = True if url else False
        self._compiled_func = (compile_func(code_str, vars_for_code) if code_str else
                               (lambda x: json.loads(x['content'])) if self._has_http else None)
        self._outputs, self._extract_from_result = outputs, extract_from_result
        if extract_from_result:
            assert outputs, 'Output information is necessary to extract output parameters'
            assert len(outputs) == 1, 'When the number of outputs is greater than 1, no manual setting is required'

    def _get_result(self, res):
        if self._extract_from_result or (isinstance(res, dict) and len(self._outputs) > 1):
            assert isinstance(res, dict), 'The result of the tool should be a dict type'
            r = package(res.get(key) for key in self._outputs)
            return r[0] if len(r) == 1 else r
        if len(self._outputs) > 1:
            assert isinstance(res, (tuple, list)), 'The result of the tool should be tuple or list'
            assert len(res) == len(self._outputs), 'The number of outputs is inconsistent with expectations'
            return package(res)
        return res

    def forward(self, *args, **kwargs):
        """
用于执行初始化时指定的操作:请求指定的 url 或者执行传入的函数。一般不直接调用,而是通过基类的 `__call__` 来调用。如果构造函数的 `url` 参数不为空,则传入的所有参数都会作为变量,用于替换在构造函数中使用 `{{}}` 标记的模板参数;如果构造函数的参数 `url` 为空,并且 `code_str` 不为空,则传入的所有参数都会作为 `code_str` 中所定义函数的参数。


Examples:

    from lazyllm.tools import HttpTool

    code_str = "def exp(v, n): return v ** n"
    tool = HttpTool(code_str=code_str)
    assert tool(v=10, n=2) == 100
    """
        if not self._compiled_func: return None
        if self._has_http:
            res = super().forward(*args, **kwargs)
            if int(res['status_code']) >= 400:
                raise RuntimeError(f'HttpRequest error, status code is {res["status_code"]}.')
            args, kwargs = (res,), {}
        res = self._compiled_func(*args, **kwargs)
        return self._get_result(res) if self._outputs else res

forward(*args, **kwargs)

用于执行初始化时指定的操作:请求指定的 url 或者执行传入的函数。一般不直接调用,而是通过基类的 __call__ 来调用。如果构造函数的 url 参数不为空,则传入的所有参数都会作为变量,用于替换在构造函数中使用 {{}} 标记的模板参数;如果构造函数的参数 url 为空,并且 code_str 不为空,则传入的所有参数都会作为 code_str 中所定义函数的参数。

Examples:

from lazyllm.tools import HttpTool

code_str = "def exp(v, n): return v ** n"
tool = HttpTool(code_str=code_str)
assert tool(v=10, n=2) == 100
Source code in lazyllm/tools/tools/http_tool.py
    def forward(self, *args, **kwargs):
        """
用于执行初始化时指定的操作:请求指定的 url 或者执行传入的函数。一般不直接调用,而是通过基类的 `__call__` 来调用。如果构造函数的 `url` 参数不为空,则传入的所有参数都会作为变量,用于替换在构造函数中使用 `{{}}` 标记的模板参数;如果构造函数的参数 `url` 为空,并且 `code_str` 不为空,则传入的所有参数都会作为 `code_str` 中所定义函数的参数。


Examples:

    from lazyllm.tools import HttpTool

    code_str = "def exp(v, n): return v ** n"
    tool = HttpTool(code_str=code_str)
    assert tool(v=10, n=2) == 100
    """
        if not self._compiled_func: return None
        if self._has_http:
            res = super().forward(*args, **kwargs)
            if int(res['status_code']) >= 400:
                raise RuntimeError(f'HttpRequest error, status code is {res["status_code"]}.')
            args, kwargs = (res,), {}
        res = self._compiled_func(*args, **kwargs)
        return self._get_result(res) if self._outputs else res

lazyllm.tools.agent.functionCall.StreamResponse

StreamResponse类用于封装带有前缀和颜色配置的流式输出行为。
当启用流式模式时,调用实例会将带颜色的文本推送到文件系统队列中,用于异步处理或显示。

Parameters:

  • prefix (str) –

    输出内容前的前缀文本,通常用于标识信息来源或类别。

  • prefix_color (Optional[str], default: None ) –

    前缀文本的颜色,支持终端颜色代码,默认无颜色。

  • color (Optional[str], default: None ) –

    主体内容文本颜色,支持终端颜色代码,默认无颜色。

  • stream (bool, default: False ) –

    是否启用流式输出模式,启用后会将文本推送至文件系统队列,默认关闭。

Examples:

>>> from lazyllm.tools.agent.functionCall import StreamResponse
>>> resp = StreamResponse(prefix="[INFO]", prefix_color="green", color="white", stream=True)
>>> resp("Hello, world!")
Hello, world!
Source code in lazyllm/tools/agent/functionCall.py
class StreamResponse():
    """StreamResponse类用于封装带有前缀和颜色配置的流式输出行为。  
当启用流式模式时,调用实例会将带颜色的文本推送到文件系统队列中,用于异步处理或显示。

Args:
    prefix (str): 输出内容前的前缀文本,通常用于标识信息来源或类别。
    prefix_color (Optional[str]): 前缀文本的颜色,支持终端颜色代码,默认无颜色。
    color (Optional[str]): 主体内容文本颜色,支持终端颜色代码,默认无颜色。
    stream (bool): 是否启用流式输出模式,启用后会将文本推送至文件系统队列,默认关闭。


Examples:
    >>> from lazyllm.tools.agent.functionCall import StreamResponse
    >>> resp = StreamResponse(prefix="[INFO]", prefix_color="green", color="white", stream=True)
    >>> resp("Hello, world!")
    Hello, world!
    """
    def __init__(self, prefix: str, prefix_color: str = None, color: str = None, stream: bool = False):
        self.stream = stream
        self.prefix = prefix
        self.prefix_color = prefix_color
        self.color = color

    def __call__(self, *inputs):
        if self.stream: FileSystemQueue().enqueue(colored_text(f'\n{self.prefix}\n', self.prefix_color))
        if len(inputs) == 1:
            if self.stream: FileSystemQueue().enqueue(colored_text(f'{inputs[0]}', self.color))
            return inputs[0]
        if self.stream: FileSystemQueue().enqueue(colored_text(f'{inputs}', self.color))
        return package(*inputs)

lazyllm.tools.MCPClient

Bases: object

Source code in lazyllm/tools/mcp/client.py
class MCPClient(object):
    def __init__(
        self,
        command_or_url: str,
        args: Optional[list[str]] = None,
        env: dict[str, str] = None,
        headers: dict[str, Any] = None,
        timeout: float = 5,
    ):
        self._command_or_url = command_or_url
        self._args = args or []
        self._env = env
        self._headers = headers
        self._timeout = timeout

    @asynccontextmanager
    async def _run_session(self):
        if urlparse(self._command_or_url).scheme in ("http", "https"):
            spec = importlib.util.find_spec("mcp.client.sse")
            if spec is None:
                raise ImportError(
                    "Please install mcp to use mcp module. "
                    "You can install it with `pip install mcp`"
                )
            sse_module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(sse_module)
            sse_client = sse_module.sse_client

            async with sse_client(
                url=self._command_or_url,
                headers=self._headers,
                timeout=self._timeout
            ) as streams:
                async with mcp.ClientSession(*streams) as session:
                    await session.initialize()
                    yield session
        else:
            server_parameters = mcp.StdioServerParameters(
                command=self._command_or_url, args=self._args, env=self._env
            )
            async with mcp.stdio_client(server_parameters) as streams:
                async with mcp.ClientSession(*streams) as session:
                    await session.initialize()
                    yield session

    async def call_tool(self, tool_name: str, arguments: dict):
        async with self._run_session() as session:
            return await session.call_tool(tool_name, arguments)

    async def list_tools(self):
        async with self._run_session() as session:
            return await session.list_tools()

    async def aget_tools(self, allowed_tools: list[str] = None):
        res = await self.list_tools()
        mcp_tools = getattr(res, "tools", [])
        if allowed_tools:
            mcp_tools = [tool for tool in mcp_tools if tool.name in allowed_tools]

        return [generate_lazyllm_tool(self, tool) for tool in mcp_tools]

    def get_tools(self, allowed_tools: list[str] = None):
        return patch_sync(self.aget_tools)(allowed_tools=allowed_tools)

    async def deploy(self, sse_settings: SseServerSettings):
        async with self._run_session() as session:
            await start_sse_server(session, sse_settings)

lazyllm.tools.tools.GoogleSearch

Bases: HttpTool

通过 Google 搜索指定的关键词。

Parameters:

  • custom_search_api_key (str) –

    用户申请的 Google API key。

  • search_engine_id (str) –

    用户创建的用于检索的搜索引擎 id。

  • timeout (int, default: 10 ) –

    搜索请求的超时时间,单位是秒,默认是 10。

  • proxies (Dict[str, str], default: None ) –

    请求时所用的代理服务。格式参考 https://www.python-httpx.org/advanced/proxies

Examples:

from lazyllm.tools.tools import GoogleSearch

key = '<your_google_search_api_key>'
cx = '<your_search_engine_id>'

google = GoogleSearch(custom_search_api_key=key, search_engine_id=cx)
Source code in lazyllm/tools/tools/google_search.py
class GoogleSearch(HttpTool):
    """
通过 Google 搜索指定的关键词。

Args:
    custom_search_api_key (str): 用户申请的 Google API key。
    search_engine_id (str): 用户创建的用于检索的搜索引擎 id。
    timeout (int): 搜索请求的超时时间,单位是秒,默认是 10。
    proxies (Dict[str, str], optional): 请求时所用的代理服务。格式参考 `https://www.python-httpx.org/advanced/proxies`。


Examples:

    from lazyllm.tools.tools import GoogleSearch

    key = '<your_google_search_api_key>'
    cx = '<your_search_engine_id>'

    google = GoogleSearch(custom_search_api_key=key, search_engine_id=cx)
    """
    # @param proxies refer to https://www.python-httpx.org/advanced/proxies
    def __init__(self, custom_search_api_key: str, search_engine_id: str,
                 timeout=10, proxies: Optional[Dict] = None):
        # refer to https://developers.google.com/custom-search/v1/reference/rest/v1/cse/list?hl=zh-cn
        params = {
            'key': custom_search_api_key,
            'cx': '{{search_engine_id}}',
            'q': '{{query}}',
            'dateRestrict': '{{date_restrict}}',
            'start': 0,
            'num': 10,
        }
        super().__init__(method='GET', url='https://customsearch.googleapis.com/customsearch/v1',
                         params=params, timeout=timeout, proxies=proxies)
        self._search_engine_id = search_engine_id

    def forward(self, query: str, date_restrict: str = 'm1',
                search_engine_id: Optional[str] = None) -> Optional[Dict]:
        """
执行搜索请求。

Args:
    query (str): 要检索的关键词。
    date_restrict (str): 要检索内容的时效性。默认检索一个月内的网页(`m1`)。参数格式可以参考 `https://developers.google.com/custom-search/v1/reference/rest/v1/cse/list?hl=zh-cn`。
    search_engine_id (str, optional): 用于检索的搜索引擎 id。如果该值为空,则使用构造函数中传入的值。


Examples:

    from lazyllm.tools.tools import GoogleSearch

    key = '<your_google_search_api_key>'
    cx = '<your_search_engine_id>'

    google = GoogleSearch(key, cx)
    res = google(query='商汤科技', date_restrict='m1')
    """
        if not search_engine_id:
            search_engine_id = self._search_engine_id

        return super().forward(query=query, search_engine_id=search_engine_id,
                               date_restrict=date_restrict)

forward(query, date_restrict='m1', search_engine_id=None)

执行搜索请求。

Parameters:

  • query (str) –

    要检索的关键词。

  • date_restrict (str, default: 'm1' ) –

    要检索内容的时效性。默认检索一个月内的网页(m1)。参数格式可以参考 https://developers.google.com/custom-search/v1/reference/rest/v1/cse/list?hl=zh-cn

  • search_engine_id (str, default: None ) –

    用于检索的搜索引擎 id。如果该值为空,则使用构造函数中传入的值。

Examples:

from lazyllm.tools.tools import GoogleSearch

key = '<your_google_search_api_key>'
cx = '<your_search_engine_id>'

google = GoogleSearch(key, cx)
res = google(query='商汤科技', date_restrict='m1')
Source code in lazyllm/tools/tools/google_search.py
    def forward(self, query: str, date_restrict: str = 'm1',
                search_engine_id: Optional[str] = None) -> Optional[Dict]:
        """
执行搜索请求。

Args:
    query (str): 要检索的关键词。
    date_restrict (str): 要检索内容的时效性。默认检索一个月内的网页(`m1`)。参数格式可以参考 `https://developers.google.com/custom-search/v1/reference/rest/v1/cse/list?hl=zh-cn`。
    search_engine_id (str, optional): 用于检索的搜索引擎 id。如果该值为空,则使用构造函数中传入的值。


Examples:

    from lazyllm.tools.tools import GoogleSearch

    key = '<your_google_search_api_key>'
    cx = '<your_search_engine_id>'

    google = GoogleSearch(key, cx)
    res = google(query='商汤科技', date_restrict='m1')
    """
        if not search_engine_id:
            search_engine_id = self._search_engine_id

        return super().forward(query=query, search_engine_id=search_engine_id,
                               date_restrict=date_restrict)

lazyllm.tools.tools.tencent_search.TencentSearch

Bases: ModuleBase

这是一个搜索增强工具。

Examples:

from lazyllm.tools.tools import TencentSearch
secret_id = '<your_secret_id>'
secret_key = '<your_secret_key>'
searcher = TencentSearch(secret_id, secret_key)
Source code in lazyllm/tools/tools/tencent_search.py
class TencentSearch(ModuleBase):
    """
这是一个搜索增强工具。


Examples:

    from lazyllm.tools.tools import TencentSearch
    secret_id = '<your_secret_id>'
    secret_key = '<your_secret_key>'
    searcher = TencentSearch(secret_id, secret_key)
    """
    def __init__(self, secret_id, secret_key):
        super().__init__()
        from tencentcloud.common.common_client import CommonClient
        from tencentcloud.common import credential
        from tencentcloud.common.profile.client_profile import ClientProfile
        from tencentcloud.common.profile.http_profile import HttpProfile

        self.cred = credential.Credential(secret_id, secret_key)
        httpProfile = HttpProfile()
        httpProfile.endpoint = "tms.tencentcloudapi.com"
        clientProfile = ClientProfile()
        clientProfile.httpProfile = httpProfile
        self.headers = {"X-TC-Action": "SearchPro"}
        self.common_client = CommonClient(
            "tms", '2020-12-29', self.cred, "", profile=clientProfile)

    def forward(self, query: str):
        """
搜索用户输入的查询。

Args:
    query (str): 用户待查询的内容。


Examples:

    from lazyllm.tools.tools import TencentSearch
    secret_id = '<your_secret_id>'
    secret_key = '<your_secret_key>'
    searcher = TencentSearch(secret_id, secret_key)
    res = searcher('calculus')
    """
        try:
            res_dict = self.common_client.call_json("SearchPro", {'Query': query, 'Mode': 2}, headers=self.headers)
            res = package(res_dict["Response"]["Pages"])
        except Exception as err:
            lazyllm.LOG.error("Request Tencent Search meets error: ", err)
            res = package()
        return res

forward(query)

搜索用户输入的查询。

Parameters:

  • query (str) –

    用户待查询的内容。

Examples:

from lazyllm.tools.tools import TencentSearch
secret_id = '<your_secret_id>'
secret_key = '<your_secret_key>'
searcher = TencentSearch(secret_id, secret_key)
res = searcher('calculus')
Source code in lazyllm/tools/tools/tencent_search.py
    def forward(self, query: str):
        """
搜索用户输入的查询。

Args:
    query (str): 用户待查询的内容。


Examples:

    from lazyllm.tools.tools import TencentSearch
    secret_id = '<your_secret_id>'
    secret_key = '<your_secret_key>'
    searcher = TencentSearch(secret_id, secret_key)
    res = searcher('calculus')
    """
        try:
            res_dict = self.common_client.call_json("SearchPro", {'Query': query, 'Mode': 2}, headers=self.headers)
            res = package(res_dict["Response"]["Pages"])
        except Exception as err:
            lazyllm.LOG.error("Request Tencent Search meets error: ", err)
            res = package()
        return res

lazyllm.tools.rag.web.WebUi

基于 Gradio 的知识库文件管理 Web UI 工具类。

该类用于构建一个简单的 Web 界面,支持创建分组、上传文件、列出/删除分组或文件,并通过 RESTful API 与后端交互。支持快速集成与展示文件管理能力。

Parameters:

  • base_url (str) –

    后端 API 服务的基础地址。

Source code in lazyllm/tools/rag/web.py
class WebUi:
    """基于 Gradio 的知识库文件管理 Web UI 工具类。

该类用于构建一个简单的 Web 界面,支持创建分组、上传文件、列出/删除分组或文件,并通过 RESTful API 与后端交互。支持快速集成与展示文件管理能力。

Args:
    base_url (str): 后端 API 服务的基础地址。
"""
    def __init__(self, base_url) -> None:
        self.base_url = base_url

    def basic_headers(self, content_type=True):
        """
生成通用的 HTTP 请求头。

Args:
    content_type (bool): 是否包含 Content-Type 头信息(默认为 True)。

Returns:
    dict: HTTP 请求头字典。
"""
        return {
            "accept": "application/json",
            "Content-Type": "application/json" if content_type else None,
        }

    def muti_headers(
        self,
    ):
        """
生成用于上传文件的 HTTP 请求头。

Returns:
    dict: HTTP 请求头字典。
"""
        return {"accept": "application/json"}

    def post_request(self, url, data):
        """
发送 POST 请求。

Args:
    url (str): 请求地址。
    data (dict): 请求数据,将被转为 JSON。

Returns:
    dict: 响应结果的 JSON。
"""
        response = requests.post(
            url, headers=self.basic_headers(), data=json.dumps(data)
        )
        return response.json()

    def get_request(self, url):
        """
发送 GET 请求。

Args:
    url (str): 请求地址。

Returns:
    dict: 响应结果的 JSON。
"""
        response = requests.get(url, headers=self.basic_headers(False))
        return response.json()

    def new_group(self, group_name: str):
        """
创建新的文件分组。

Args:
    group_name (str): 分组名称。

Returns:
    str: 创建结果的提示信息。
"""
        response = requests.post(
            f"{self.base_url}/new_group?group_name={group_name}",
            headers=self.basic_headers(True),
        )
        return response.json()["msg"]

    def delete_group(self, group_name: str):
        """
删除指定的文件分组。

Args:
    group_name (str): 分组名称。

Returns:
    str: 删除结果信息。
"""
        response = requests.post(
            f"{self.base_url}/delete_group?group_name={group_name}",
            headers=self.basic_headers(True),
        )
        return response.json()["msg"]

    def list_groups(self):
        """
列出所有文件分组。

Returns:
    List[str]: 分组名称列表。
"""
        response = requests.get(
            f"{self.base_url}/list_kb_groups", headers=self.basic_headers(False)
        )
        return response.json()["data"]

    def upload_files(self, group_name: str, override: bool = True):
        """
向指定分组上传文件。

Args:
    group_name (str): 分组名称。
    override (bool): 是否覆盖已存在的文件(默认 True)。

Returns:
    Any: 后端返回的上传结果数据。
"""
        response = requests.post(
            f"{self.base_url}/upload_files?group_name={group_name}&override={override}",
            headers=self.basic_headers(True),
        )
        return response.json()["data"]

    def list_files_in_group(self, group_name: str):
        """
列出指定分组下的所有文件。

Args:
    group_name (str): 分组名称。

Returns:
    List: 文件信息列表。
"""
        response = requests.get(
            f"{self.base_url}/list_files_in_group?group_name={group_name}&alive=True",
            headers=self.basic_headers(False),
        )
        return response.json()["data"]

    def delete_file(self, group_name: str, file_ids: list[str]):
        """
从指定分组中删除文件。

Args:
    group_name (str): 分组名称。
    file_ids (List[str]): 要删除的文件 ID 列表。

Returns:
    str: 删除结果提示。
"""
        response = requests.post(
            f"{self.base_url}/delete_files_from_group",
            headers=self.basic_headers(True),
            json={"group_name": group_name, "file_ids": file_ids}
        )
        return response.json()["msg"]

    def gr_show_list(self, str_list: list, list_name: Union[str, list]):
        """
以 Gradio 表格的形式展示字符串列表。

Args:
    str_list (List): 字符串或子项列表。
    list_name (Union[str, List]): 表头名称或列名列表。

Returns:
    gr.DataFrame: Gradio 表格组件。
"""
        if isinstance(list_name, str):
            headers = ["index", list_name]
            value = [[index, str_list[index]] for index in range(len(str_list))]
        else:
            headers = ["index"] + list_name
            value = [[index] + str_list[index:index + len(list_name)] for index in range(len(str_list))]
        return gr.DataFrame(headers=headers, value=value)

    def create_ui(self):
        """
构建基于 Gradio 的文件管理图形界面,包含分组列表、上传、查看、删除等功能标签页。

Returns:
    gr.Blocks: 完整的 Gradio UI 应用实例。
"""
        with gr.Blocks(analytics_enabled=False) as demo:
            with gr.Tabs():
                select_group_list = []

                with gr.TabItem("分组列表"):
                    select_group = self.gr_show_list(
                        self.list_groups(), list_name="group_name"
                    )
                    select_group_list.append(select_group)

                with gr.TabItem("上传文件"):

                    def _upload_files(group_name, files):

                        files_to_upload = [
                            ("files", (os.path.basename(file), open(file, "rb")))
                            for file in files
                        ]

                        url = f"{self.base_url}/add_files_to_group?group_name={group_name}&override=true"
                        response = requests.post(
                            url, files=files_to_upload, headers=self.muti_headers()
                        )
                        response.raise_for_status()
                        response_data = response.json()
                        gr.Info(str(response_data["msg"]))

                        for _, (_, file_obj) in files_to_upload:
                            file_obj.close()

                    select_group = gr.Dropdown(self.list_groups(), label="选择分组")
                    select_group.change(lambda x: x, inputs=select_group, outputs=None)

                    up_files = gr.Files(label="上传文件")
                    up_btn = gr.Button("上传")
                    up_btn.click(
                        _upload_files,
                        inputs=[select_group, up_files],
                        outputs=None,
                    )

                    select_group_list.append(select_group)

                with gr.TabItem("分组文件列表"):
                    def _list_group_files(group_name):
                        file_list = self.list_files_in_group(group_name)
                        values = [[i] + file_list[i][:2] for i in range(len(file_list))]
                        return gr.update(
                            value=values
                        )

                    select_group = gr.Dropdown(self.list_groups(), label="选择分组")
                    show_list = self.gr_show_list([], list_name=["file_id", "file_name"])
                    select_group.change(
                        fn=_list_group_files, inputs=select_group, outputs=show_list
                    )
                    select_group_list.append(select_group)

                with gr.TabItem("删除文件"):

                    def _list_group_files(group_name):
                        file_list = self.list_files_in_group(group_name)
                        file_list = [','.join(file[:2]) for file in file_list]
                        return gr.update(choices=file_list)

                    select_group = gr.Dropdown(self.list_groups(), label="选择分组")
                    select_file = gr.Dropdown([], label="选择文件")
                    select_group.change(
                        fn=_list_group_files, inputs=select_group, outputs=select_file
                    )
                    delete_btn = gr.Button("删除")

                    def _delete_file(group_name, select_file):
                        file_ids = [select_file.split(',')[0]]
                        gr.Info(self.delete_file(group_name, file_ids))
                        return _list_group_files(group_name)

                    delete_btn.click(
                        fn=_delete_file,
                        inputs=[select_group, select_file],
                        outputs=select_file,
                    )
                    select_group_list.append(select_group)

        return demo

basic_headers(content_type=True)

生成通用的 HTTP 请求头。

Parameters:

  • content_type (bool, default: True ) –

    是否包含 Content-Type 头信息(默认为 True)。

Returns:

  • dict

    HTTP 请求头字典。

Source code in lazyllm/tools/rag/web.py
    def basic_headers(self, content_type=True):
        """
生成通用的 HTTP 请求头。

Args:
    content_type (bool): 是否包含 Content-Type 头信息(默认为 True)。

Returns:
    dict: HTTP 请求头字典。
"""
        return {
            "accept": "application/json",
            "Content-Type": "application/json" if content_type else None,
        }

create_ui()

构建基于 Gradio 的文件管理图形界面,包含分组列表、上传、查看、删除等功能标签页。

Returns:

  • gr.Blocks: 完整的 Gradio UI 应用实例。

Source code in lazyllm/tools/rag/web.py
    def create_ui(self):
        """
构建基于 Gradio 的文件管理图形界面,包含分组列表、上传、查看、删除等功能标签页。

Returns:
    gr.Blocks: 完整的 Gradio UI 应用实例。
"""
        with gr.Blocks(analytics_enabled=False) as demo:
            with gr.Tabs():
                select_group_list = []

                with gr.TabItem("分组列表"):
                    select_group = self.gr_show_list(
                        self.list_groups(), list_name="group_name"
                    )
                    select_group_list.append(select_group)

                with gr.TabItem("上传文件"):

                    def _upload_files(group_name, files):

                        files_to_upload = [
                            ("files", (os.path.basename(file), open(file, "rb")))
                            for file in files
                        ]

                        url = f"{self.base_url}/add_files_to_group?group_name={group_name}&override=true"
                        response = requests.post(
                            url, files=files_to_upload, headers=self.muti_headers()
                        )
                        response.raise_for_status()
                        response_data = response.json()
                        gr.Info(str(response_data["msg"]))

                        for _, (_, file_obj) in files_to_upload:
                            file_obj.close()

                    select_group = gr.Dropdown(self.list_groups(), label="选择分组")
                    select_group.change(lambda x: x, inputs=select_group, outputs=None)

                    up_files = gr.Files(label="上传文件")
                    up_btn = gr.Button("上传")
                    up_btn.click(
                        _upload_files,
                        inputs=[select_group, up_files],
                        outputs=None,
                    )

                    select_group_list.append(select_group)

                with gr.TabItem("分组文件列表"):
                    def _list_group_files(group_name):
                        file_list = self.list_files_in_group(group_name)
                        values = [[i] + file_list[i][:2] for i in range(len(file_list))]
                        return gr.update(
                            value=values
                        )

                    select_group = gr.Dropdown(self.list_groups(), label="选择分组")
                    show_list = self.gr_show_list([], list_name=["file_id", "file_name"])
                    select_group.change(
                        fn=_list_group_files, inputs=select_group, outputs=show_list
                    )
                    select_group_list.append(select_group)

                with gr.TabItem("删除文件"):

                    def _list_group_files(group_name):
                        file_list = self.list_files_in_group(group_name)
                        file_list = [','.join(file[:2]) for file in file_list]
                        return gr.update(choices=file_list)

                    select_group = gr.Dropdown(self.list_groups(), label="选择分组")
                    select_file = gr.Dropdown([], label="选择文件")
                    select_group.change(
                        fn=_list_group_files, inputs=select_group, outputs=select_file
                    )
                    delete_btn = gr.Button("删除")

                    def _delete_file(group_name, select_file):
                        file_ids = [select_file.split(',')[0]]
                        gr.Info(self.delete_file(group_name, file_ids))
                        return _list_group_files(group_name)

                    delete_btn.click(
                        fn=_delete_file,
                        inputs=[select_group, select_file],
                        outputs=select_file,
                    )
                    select_group_list.append(select_group)

        return demo

delete_file(group_name, file_ids)

从指定分组中删除文件。

Parameters:

  • group_name (str) –

    分组名称。

  • file_ids (List[str]) –

    要删除的文件 ID 列表。

Returns:

  • str

    删除结果提示。

Source code in lazyllm/tools/rag/web.py
    def delete_file(self, group_name: str, file_ids: list[str]):
        """
从指定分组中删除文件。

Args:
    group_name (str): 分组名称。
    file_ids (List[str]): 要删除的文件 ID 列表。

Returns:
    str: 删除结果提示。
"""
        response = requests.post(
            f"{self.base_url}/delete_files_from_group",
            headers=self.basic_headers(True),
            json={"group_name": group_name, "file_ids": file_ids}
        )
        return response.json()["msg"]

delete_group(group_name)

删除指定的文件分组。

Parameters:

  • group_name (str) –

    分组名称。

Returns:

  • str

    删除结果信息。

Source code in lazyllm/tools/rag/web.py
    def delete_group(self, group_name: str):
        """
删除指定的文件分组。

Args:
    group_name (str): 分组名称。

Returns:
    str: 删除结果信息。
"""
        response = requests.post(
            f"{self.base_url}/delete_group?group_name={group_name}",
            headers=self.basic_headers(True),
        )
        return response.json()["msg"]

get_request(url)

发送 GET 请求。

Parameters:

  • url (str) –

    请求地址。

Returns:

  • dict

    响应结果的 JSON。

Source code in lazyllm/tools/rag/web.py
    def get_request(self, url):
        """
发送 GET 请求。

Args:
    url (str): 请求地址。

Returns:
    dict: 响应结果的 JSON。
"""
        response = requests.get(url, headers=self.basic_headers(False))
        return response.json()

gr_show_list(str_list, list_name)

以 Gradio 表格的形式展示字符串列表。

Parameters:

  • str_list (List) –

    字符串或子项列表。

  • list_name (Union[str, List]) –

    表头名称或列名列表。

Returns:

  • gr.DataFrame: Gradio 表格组件。

Source code in lazyllm/tools/rag/web.py
    def gr_show_list(self, str_list: list, list_name: Union[str, list]):
        """
以 Gradio 表格的形式展示字符串列表。

Args:
    str_list (List): 字符串或子项列表。
    list_name (Union[str, List]): 表头名称或列名列表。

Returns:
    gr.DataFrame: Gradio 表格组件。
"""
        if isinstance(list_name, str):
            headers = ["index", list_name]
            value = [[index, str_list[index]] for index in range(len(str_list))]
        else:
            headers = ["index"] + list_name
            value = [[index] + str_list[index:index + len(list_name)] for index in range(len(str_list))]
        return gr.DataFrame(headers=headers, value=value)

list_files_in_group(group_name)

列出指定分组下的所有文件。

Parameters:

  • group_name (str) –

    分组名称。

Returns:

  • List

    文件信息列表。

Source code in lazyllm/tools/rag/web.py
    def list_files_in_group(self, group_name: str):
        """
列出指定分组下的所有文件。

Args:
    group_name (str): 分组名称。

Returns:
    List: 文件信息列表。
"""
        response = requests.get(
            f"{self.base_url}/list_files_in_group?group_name={group_name}&alive=True",
            headers=self.basic_headers(False),
        )
        return response.json()["data"]

list_groups()

列出所有文件分组。

Returns:

  • List[str]: 分组名称列表。

Source code in lazyllm/tools/rag/web.py
    def list_groups(self):
        """
列出所有文件分组。

Returns:
    List[str]: 分组名称列表。
"""
        response = requests.get(
            f"{self.base_url}/list_kb_groups", headers=self.basic_headers(False)
        )
        return response.json()["data"]

muti_headers()

生成用于上传文件的 HTTP 请求头。

Returns:

  • dict

    HTTP 请求头字典。

Source code in lazyllm/tools/rag/web.py
    def muti_headers(
        self,
    ):
        """
生成用于上传文件的 HTTP 请求头。

Returns:
    dict: HTTP 请求头字典。
"""
        return {"accept": "application/json"}

new_group(group_name)

创建新的文件分组。

Parameters:

  • group_name (str) –

    分组名称。

Returns:

  • str

    创建结果的提示信息。

Source code in lazyllm/tools/rag/web.py
    def new_group(self, group_name: str):
        """
创建新的文件分组。

Args:
    group_name (str): 分组名称。

Returns:
    str: 创建结果的提示信息。
"""
        response = requests.post(
            f"{self.base_url}/new_group?group_name={group_name}",
            headers=self.basic_headers(True),
        )
        return response.json()["msg"]

post_request(url, data)

发送 POST 请求。

Parameters:

  • url (str) –

    请求地址。

  • data (dict) –

    请求数据,将被转为 JSON。

Returns:

  • dict

    响应结果的 JSON。

Source code in lazyllm/tools/rag/web.py
    def post_request(self, url, data):
        """
发送 POST 请求。

Args:
    url (str): 请求地址。
    data (dict): 请求数据,将被转为 JSON。

Returns:
    dict: 响应结果的 JSON。
"""
        response = requests.post(
            url, headers=self.basic_headers(), data=json.dumps(data)
        )
        return response.json()

upload_files(group_name, override=True)

向指定分组上传文件。

Parameters:

  • group_name (str) –

    分组名称。

  • override (bool, default: True ) –

    是否覆盖已存在的文件(默认 True)。

Returns:

  • Any

    后端返回的上传结果数据。

Source code in lazyllm/tools/rag/web.py
    def upload_files(self, group_name: str, override: bool = True):
        """
向指定分组上传文件。

Args:
    group_name (str): 分组名称。
    override (bool): 是否覆盖已存在的文件(默认 True)。

Returns:
    Any: 后端返回的上传结果数据。
"""
        response = requests.post(
            f"{self.base_url}/upload_files?group_name={group_name}&override={override}",
            headers=self.basic_headers(True),
        )
        return response.json()["data"]

lazyllm.tools.http_request.http_executor_response.HttpExecutorResponse

Source code in lazyllm/tools/http_request/http_executor_response.py
class HttpExecutorResponse:
    headers: dict[str, str]
    response: httpx.Response

    def __init__(self, response: httpx.Response = None):
        self.response = response
        self.headers = dict(response.headers) if isinstance(self.response, httpx.Response) else {}

    @property
    def is_file(self) -> bool:
        """
        check if response is file
        """
        content_type = self.get_content_type()
        file_content_types = ['image', 'audio', 'video']

        return any(v in content_type for v in file_content_types)

    def get_content_type(self) -> str:
        """获取HTTP响应的内容类型。

从响应头中提取 'content-type' 字段的值,用于判断响应内容的类型。

Returns:
    str: 响应的内容类型,如果未找到则返回空字符串。


Examples:
    >>> from lazyllm.tools.http_request.http_executor_response import HttpExecutorResponse
    >>> import httpx
    >>> response = httpx.Response(200, headers={'content-type': 'application/json'})
    >>> http_response = HttpExecutorResponse(response)
    >>> content_type = http_response.get_content_type()
    >>> print(content_type)
    ... 'application/json'
    """
        return self.headers.get('content-type', '')

    def extract_file(self) -> tuple[str, bytes]:
        """
        extract file from response if content type is file related
        """
        if self.is_file:
            return self.get_content_type(), self.body

        return '', b''

    @property
    def content(self) -> str:
        if isinstance(self.response, httpx.Response):
            return self.response.text
        else:
            raise ValueError(f'Invalid response type {type(self.response)}')

    @property
    def body(self) -> bytes:
        if isinstance(self.response, httpx.Response):
            return self.response.content
        else:
            raise ValueError(f'Invalid response type {type(self.response)}')

    @property
    def status_code(self) -> int:
        if isinstance(self.response, httpx.Response):
            return self.response.status_code
        else:
            raise ValueError(f'Invalid response type {type(self.response)}')

is_file property

check if response is file

extract_file()

extract file from response if content type is file related

Source code in lazyllm/tools/http_request/http_executor_response.py
def extract_file(self) -> tuple[str, bytes]:
    """
    extract file from response if content type is file related
    """
    if self.is_file:
        return self.get_content_type(), self.body

    return '', b''

get_content_type()

获取HTTP响应的内容类型。

从响应头中提取 'content-type' 字段的值,用于判断响应内容的类型。

Returns:

  • str ( str ) –

    响应的内容类型,如果未找到则返回空字符串。

Examples:

>>> from lazyllm.tools.http_request.http_executor_response import HttpExecutorResponse
>>> import httpx
>>> response = httpx.Response(200, headers={'content-type': 'application/json'})
>>> http_response = HttpExecutorResponse(response)
>>> content_type = http_response.get_content_type()
>>> print(content_type)
... 'application/json'
Source code in lazyllm/tools/http_request/http_executor_response.py
    def get_content_type(self) -> str:
        """获取HTTP响应的内容类型。

从响应头中提取 'content-type' 字段的值,用于判断响应内容的类型。

Returns:
    str: 响应的内容类型,如果未找到则返回空字符串。


Examples:
    >>> from lazyllm.tools.http_request.http_executor_response import HttpExecutorResponse
    >>> import httpx
    >>> response = httpx.Response(200, headers={'content-type': 'application/json'})
    >>> http_response = HttpExecutorResponse(response)
    >>> content_type = http_response.get_content_type()
    >>> print(content_type)
    ... 'application/json'
    """
        return self.headers.get('content-type', '')