Add RWKV support for Agent
Browse files
agents.py
CHANGED
|
@@ -13,6 +13,15 @@ from langchain_community.document_loaders import ArxivLoader
|
|
| 13 |
from langchain_community.vectorstores import SupabaseVectorStore
|
| 14 |
from langchain_core.messages import SystemMessage, HumanMessage
|
| 15 |
from langchain_core.tools import tool
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
load_dotenv()
|
| 18 |
|
|
@@ -144,7 +153,7 @@ tools = [
|
|
| 144 |
]
|
| 145 |
|
| 146 |
# Build graph function
|
| 147 |
-
def build_graph(provider: str = "
|
| 148 |
"""Build the graph"""
|
| 149 |
# Load environment variables from .env file
|
| 150 |
if provider == "google":
|
|
@@ -161,6 +170,31 @@ def build_graph(provider: str = "groq"):
|
|
| 161 |
temperature=0,
|
| 162 |
),
|
| 163 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
else:
|
| 165 |
raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
|
| 166 |
# Bind tools to LLM
|
|
|
|
| 13 |
from langchain_community.vectorstores import SupabaseVectorStore
|
| 14 |
from langchain_core.messages import SystemMessage, HumanMessage
|
| 15 |
from langchain_core.tools import tool
|
| 16 |
+
from huggingface_hub import hf_hub_download
|
| 17 |
+
|
| 18 |
+
# RWKV setup flags — must come before importing RWKV
|
| 19 |
+
os.environ["RWKV_JIT_ON"] = "1"
|
| 20 |
+
os.environ["RWKV_V7_ON"] = "1" # enable RWKV-7
|
| 21 |
+
os.environ["RWKV_CUDA_ON"] = "0" # set to "1"
|
| 22 |
+
|
| 23 |
+
from rwkv.model import RWKV
|
| 24 |
+
from rwkv.utils import PIPELINE
|
| 25 |
|
| 26 |
load_dotenv()
|
| 27 |
|
|
|
|
| 153 |
]
|
| 154 |
|
| 155 |
# Build graph function
|
| 156 |
+
def build_graph(provider: str = "rwkv"):
|
| 157 |
"""Build the graph"""
|
| 158 |
# Load environment variables from .env file
|
| 159 |
if provider == "google":
|
|
|
|
| 170 |
temperature=0,
|
| 171 |
),
|
| 172 |
)
|
| 173 |
+
elif provider == "rwkv":
|
| 174 |
+
# --- BEGIN RWKV SETUP ---
|
| 175 |
+
title = "rwkv7-g1"
|
| 176 |
+
pth = hf_hub_download(repo_id="BlinkDL/rwkv7-g1", filename=f"{title}.pth")
|
| 177 |
+
# 2) Load RWKV (drop .pth extension for RWKV loader)
|
| 178 |
+
rwkv_model = RWKV(model=pth.replace(".pth", ""), strategy="cpu fp32")
|
| 179 |
+
# 3) Build the tokenization + generation pipeline
|
| 180 |
+
rwkv_pipe = PIPELINE(rwkv_model, "rwkv_vocab_v20230424")
|
| 181 |
+
# 4) Wrap into a Chat-style interface
|
| 182 |
+
class ChatRWKV:
|
| 183 |
+
def __init__(self, pipe):
|
| 184 |
+
self.pipe = pipe
|
| 185 |
+
def invoke(self, messages):
|
| 186 |
+
prompt = "\n".join(m.content for m in messages)
|
| 187 |
+
return self.pipe(
|
| 188 |
+
prompt,
|
| 189 |
+
temperature=0.0,
|
| 190 |
+
top_p=0.95,
|
| 191 |
+
max_tokens=256,
|
| 192 |
+
)
|
| 193 |
+
def bind_tools(self, tools):
|
| 194 |
+
return self
|
| 195 |
+
|
| 196 |
+
llm = ChatRWKV(rwkv_pipe)
|
| 197 |
+
# --- END RWKV SETUP ---
|
| 198 |
else:
|
| 199 |
raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
|
| 200 |
# Bind tools to LLM
|