1、万能计算器Agent

1.1 功能目标

使用LangChain开发一个万能计算器Agent,涵盖算术运算、矩阵运算、逻辑运算等多种功能。

1.2 结合LLM方案

引入LLM(大型语言模型)来解析用户输入并调用相应的计算方法,可以使万能计算器更加智能化和通用。我们可以使用预训练的模型(如Meta的Llama3)来实现自然语言理解,然后调用不同的计算函数。

2、实现步骤

2.1 安装依赖

首先,安装必要的依赖库:

pip install llama3 langchain numpy sympy

确保你已经在环境变量中设置了你的LLAMA3 API密钥。你可以在代码中直接设置,也可以在环境变量中配置。

import os
os.environ['LLAMA3_API_KEY'] = 'your_llama3_api_key'

2.2 定义万能计算器Agent

使用LangChain来构建一个包含多种计算功能的Agent。我们将使用以下库:

  • numpy 用于矩阵运算
  • sympy 用于符号计算和逻辑运算

2.2.1 创建Calculator类

这个类将包含算术运算、矩阵运算和逻辑运算的方法。

import numpy as np
import sympy as sp

class Calculator:
    def arithmetic(self, expression: str) -> float:
        try:
            result = eval(expression)
            return result
        except Exception as e:
            return str(e)
    
    def matrix_multiplication(self, matrix1: list, matrix2: list) -> np.ndarray:
        try:
            result = np.dot(matrix1, matrix2)
            return result
        except Exception as e:
            return str(e)
    
    def logical_operation(self, expression: str) -> bool:
        try:
            result = eval(expression)
            return result
        except Exception as e:
            return str(e)
    
    def evaluate_expression(self, expression: str):
        try:
            result = sp.sympify(expression)
            return result
        except Exception as e:
            return str(e)

2.2.2 创建LangChain Agent

定义一个Agent来处理用户的输入,使用LangChain和Llama3 API来处理用户输入并调用相应的计算方法。

from langchain import Chain
import llama3

class CalculatorAgent(Chain):
    def __init__(self):
        super().__init__()
        self.calculator = Calculator()
    
    def _call(self, inputs):
        query = inputs["query"]
        
        # 使用Llama3 API来解析输入并格式化
        response = llama3.Completion.create(
            engine="llama3",
            prompt=f"解析以下问题并格式化为正确的输入格式:\n{query}\n类型包括:算术运算、矩阵运算、逻辑运算。请提供适当的格式。",
            max_tokens=100
        )
        
        formatted_query = response.choices[0].text.strip()
        analysis = formatted_query.split(":")[0].strip().lower()
        formatted_expression = formatted_query.split(":")[1].strip()
        
        if "矩阵" in analysis:
            try:
                parts = formatted_expression.split("乘以")
                matrix1 = eval(parts[0].strip())
                matrix2 = eval(parts[1].strip())
                result = self.calculator.matrix_multiplication(matrix1, matrix2)
                return {"result": result}
            except Exception as e:
                return {"result": str(e)}
        elif "逻辑" in analysis:
            try:
                result = self.calculator.logical_operation(formatted_expression)
                return {"result": result}
            except Exception as e:
                return {"result": str(e)}
        else:
            try:
                result = self.calculator.arithmetic(formatted_expression)
                return {"result": result}
            except Exception as e:
                return {"result": str(e)}
    
    @property
    def input_keys(self):
        return ["query"]
    
    @property
    def output_keys(self):
        return ["result"]

2.2.3 运行Agent

创建一个实例并运行它。

if __name__ == "__main__":
    agent = CalculatorAgent()
    
    # Test arithmetic
    result = agent({"query": "2 + 2"})
    print("Arithmetic result:", result["result"])
    
    # Test matrix multiplication
    result = agent({"query": "矩阵 [[1, 2], [3, 4]] 乘以 [[5, 6], [7, 8]]"})
    print("Matrix multiplication result:\n", result["result"])
    
    # Test logical operation
    result = agent({"query": "True and False"})
    print("Logical operation result:", result["result"])

2.2.4 完整代码

将所有部分整合成一个完整的脚本:

import os
import numpy as np
import sympy as sp
from langchain import Chain
import llama3

os.environ['LLAMA3_API_KEY'] = 'your_llama3_api_key'

class Calculator:
    def arithmetic(self, expression: str) -> float:
        try:
            result = eval(expression)
            return result
        except Exception as e:
            return str(e)
    
    def matrix_multiplication(self, matrix1: list, matrix2: list) -> np.ndarray:
        try:
            result = np.dot(matrix1, matrix2)
            return result
        except Exception as e:
            return str(e)
    
    def logical_operation(self, expression: str) -> bool:
        try:
            result = eval(expression)
            return result
        except Exception as e:
            return str(e)
    
    def evaluate_expression(self, expression: str):
        try:
            result = sp.sympify(expression)
            return result
        except Exception as e:
            return str(e)

class CalculatorAgent(Chain):
    def __init__(self):
        super().__init__()
        self.calculator = Calculator()
    
    def _call(self, inputs):
        query = inputs["query"]
        
        # 使用Llama3 API来解析输入并格式化
        response = llama3.Completion.create(
            engine="llama3",
            prompt=f"解析以下问题并格式化为正确的输入格式:\n{query}\n类型包括:算术运算、矩阵运算、逻辑运算。请提供适当的格式。",
            max_tokens=100
        )
        
        formatted_query = response.choices[0].text.strip()
        analysis = formatted_query.split(":")[0].strip().lower()
        formatted_expression = formatted_query.split(":")[1].strip()
        
        if "矩阵" in analysis:
            try:
                parts = formatted_expression.split("乘以")
                matrix1 = eval(parts[0].strip())
                matrix2 = eval(parts[1].strip())
                result = self.calculator.matrix_multiplication(matrix1, matrix2)
                return {"result": result}
            except Exception as e:
                return {"result": str(e)}
        elif "逻辑" in analysis:
            try:
                result = self.calculator.logical_operation(formatted_expression)
                return {"result": result}
            except Exception as e:
                return {"result": str(e)}
        else:
            try:
                result = self.calculator.arithmetic(formatted_expression)
                return {"result": result}
            except Exception as e:
                return {"result": str(e)}
    
    @property
    def input_keys(self):
        return ["query"]
    
    @property
    def output_keys(self):
        return ["result"]

if __name__ == "__main__":
    agent = CalculatorAgent()
    
    # Test arithmetic
    result = agent({"query": "2 + 2"})
    print("Arithmetic result:", result["result"])
    
    # Test matrix multiplication
    result = agent({"query": "矩阵 [[1, 2], [3, 4]] 乘以 [[5, 6], [7, 8]]"})
    print("Matrix multiplication result:\n", result["result"])
    
    # Test logical operation
    result = agent({"query": "True and False"})
    print("Logical operation result:", result["result"])

通过以上步骤,你可以构建一个智能的万能计算器Agent,能够自动解析用户输入并执行相应的计算任务。根据实际需求,你可以进一步扩展这个Agent的功能。

3、使Agent更加智能

3.1 优化目标

假设用户的输入是这样的:

矩阵1=[[1, 2], [3, 4]],矩阵2=[[5, 6], [7, 8]],请计算他们的乘积

要求“万能计算器Agent”依然能够准确识别并计算。

3.2 实现方案

为了让Llama3能够正确理解并计算用户输入的矩阵乘积,我们可以考虑进行以下步骤:

  1. 微调Llama3:虽然Llama3可能已经具备一定的理解能力,但通过微调,我们可以让它更好地理解特定格式的输入和任务。微调的过程涉及准备数据集和训练模型。
  2. 使用预处理和模板:如果不进行微调,可以通过预处理输入并使用模板来提高模型的理解能力。

在此,我们首先尝试使用预处理和模板的方法。如果效果不佳,可以进一步考虑微调模型。

3.3 预处理输入方案

我们可以预处理用户的输入,将其格式化为更易解析的形式。

import os
import numpy as np
import sympy as sp
from langchain import Chain
import llama3

os.environ['LLAMA3_API_KEY'] = 'your_llama3_api_key'

class Calculator:
    def arithmetic(self, expression: str) -> float:
        try:
            result = eval(expression)
            return result
        except Exception as e:
            return str(e)
    
    def matrix_multiplication(self, matrix1: list, matrix2: list) -> np.ndarray:
        try:
            result = np.dot(matrix1, matrix2)
            return result
        except Exception as e:
            return str(e)
    
    def logical_operation(self, expression: str) -> bool:
        try:
            result = eval(expression)
            return result
        except Exception as e:
            return str(e)
    
    def evaluate_expression(self, expression: str):
        try:
            result = sp.sympify(expression)
            return result
        except Exception as e:
            return str(e)

class CalculatorAgent(Chain):
    def __init__(self):
        super().__init__()
        self.calculator = Calculator()
    
    def _call(self, inputs):
        query = inputs["query"]
        
        # 使用Llama3 API来解析输入并格式化
        response = llama3.Completion.create(
            engine="llama3",
            prompt=f"请解析以下问题并格式化为正确的输入格式:\n{query}\n类型包括:算术运算、矩阵运算、逻辑运算。请提供适当的格式。",
            max_tokens=100
        )
        
        formatted_query = response.choices[0].text.strip()
        analysis = formatted_query.split(":")[0].strip().lower()
        formatted_expression = formatted_query.split(":")[1].strip()
        
        if "矩阵" in analysis:
            try:
                # 将格式化后的矩阵解析并计算乘积
                parts = formatted_expression.split("乘以")
                matrix1 = eval(parts[0].strip())
                matrix2 = eval(parts[1].strip())
                result = self.calculator.matrix_multiplication(matrix1, matrix2)
                return {"result": result}
            except Exception as e:
                return {"result": str(e)}
        elif "逻辑" in analysis:
            try:
                result = self.calculator.logical_operation(formatted_expression)
                return {"result": result}
            except Exception as e:
                return {"result": str(e)}
        else:
            try:
                result = self.calculator.arithmetic(formatted_expression)
                return {"result": result}
            except Exception as e:
                return {"result": str(e)}
    
    @property
    def input_keys(self):
        return ["query"]
    
    @property
    def output_keys(self):
        return ["result"]

if __name__ == "__main__":
    agent = CalculatorAgent()
    
    # Test arithmetic
    result = agent({"query": "2 + 2"})
    print("Arithmetic result:", result["result"])
    
    # Test matrix multiplication
    result = agent({"query": "矩阵1=[[1, 2], [3, 4]],矩阵2=[[5, 6], [7, 8]],请计算他们的乘积。"})
    print("Matrix multiplication result:\n", result["result"])
    
    # Test logical operation
    result = agent({"query": "True and False"})
    print("Logical operation result:", result["result"])

3.4 微调方案

如果预处理方法无法完全满足需求,微调Llama3将是一个更好的解决方案。微调过程包括以下步骤:

3.4.1 准备数据集

准备一个包含各种输入示例的数据集,格式如下:

[
    {"input": "矩阵1=[[1, 2], [3, 4]],矩阵2=[[5, 6], [7, 8]],请计算他们的乘积。", "output": "矩阵运算: 矩阵1=[[1, 2], [3, 4]] 乘以 矩阵2=[[5, 6], [7, 8]]"},
    {"input": "2 + 2", "output": "算术运算: 2 + 2"},
    {"input": "True and False", "output": "逻辑运算: True and False"}
]

3.4.2 进行微调

使用Llama3的微调工具进行模型微调。具体的微调代码和工具可以参考Llama3的官方文档

3.4.3 使用微调后的模型

使用微调后的模型来解析用户输入并进行计算。

# 假设微调后的模型已经训练好并保存
import llama3

llama3.api_key = 'your_llama3_api_key'

class CalculatorAgent(Chain):
    def __init__(self):
        super().__init__()
        self.calculator = Calculator()
    
    def _call(self, inputs):
        query = inputs["query"]
        
        # 使用微调后的Llama3模型来解析输入并格式化
        response = llama3.Completion.create(
            engine="fine-tuned-llama3",
            prompt=f"请解析以下问题并格式化为正确的输入格式:\n{query}\n类型包括:算术运算、矩阵运算、逻辑运算。请提供适当的格式。",
            max_tokens=100
        )
        
        formatted_query = response.choices[0].text.strip()
        analysis = formatted_query.split(":")[0].strip().lower()
        formatted_expression = formatted_query.split(":")[1].strip()
        
        if "矩阵" in analysis:
            try:
                parts = formatted_expression.split("乘以")
                matrix1 = eval(parts[0].strip())
                matrix2 = eval(parts[1].strip())
                result = self.calculator.matrix_multiplication(matrix1, matrix2)
                return {"result": result}
            except Exception as e:
                return {"result": str(e)}
        elif "逻辑" in analysis:
            try:
                result = self.calculator.logical_operation(formatted_expression)
                return {"result": result}
            except Exception as e:
                return {"result": str(e)}
        else:
            try:
                result = self.calculator.arithmetic(formatted_expression)
                return {"result": result}
            except Exception as e:
                return {"result": str(e)}
    
    @property
    def input_keys(self):
        return ["query"]
    
    @property
    def output_keys(self):
        return ["result"]

通过以上步骤,你可以构建一个更智能的万能计算器Agent,能够自动解析用户输入并执行相应的计算任务。微调Llama3可以进一步提升模型的理解和处理能力。

4、补充说明

本文只是提供一种研发思路,文中代码并没有在真机上运行通过,大家可以根据需要自行修改。