railtracks
The Railtracks Framework for building resilient agentic systems in simple Python
1# ------------------------------------------------------------- 2# Copyright (c) Railtown AI. All rights reserved. 3# Licensed under the MIT License. See LICENSE in project root for information. 4# ------------------------------------------------------------- 5"""The Railtracks Framework for building resilient agentic systems in simple python""" 6 7from __future__ import annotations 8 9import logging 10 11from dotenv import load_dotenv 12 13__all__ = [ 14 "Session", 15 "session", 16 "call", 17 "broadcast", 18 "call_batch", 19 "interactive", 20 "ExecutionInfo", 21 "ExecutorConfig", 22 "llm", 23 "context", 24 "set_config", 25 "context", 26 "function_node", 27 "agent_node", 28 "integrations", 29 "prebuilt", 30 "MCPStdioParams", 31 "MCPHttpParams", 32 "connect_mcp", 33 "create_mcp_server", 34 "ToolManifest", 35 "session_id", 36 "vector_stores", 37 "rag", 38 "RagConfig", 39 "Flow", 40 "enable_logging", 41] 42 43from railtracks.built_nodes.concrete.rag import RagConfig 44from railtracks.built_nodes.easy_usage_wrappers import ( 45 agent_node, 46 function_node, 47) 48 49from . import context, integrations, llm, prebuilt, rag, vector_stores 50from ._session import ExecutionInfo, Session, session 51from .context.central import session_id, set_config 52from .interaction import broadcast, call, call_batch, interactive 53from .nodes.manifest import ToolManifest 54from .orchestration.flow import Flow 55from .rt_mcp import MCPHttpParams, MCPStdioParams, connect_mcp, create_mcp_server 56from .utils.config import ExecutorConfig 57from .utils.logging.config import enable_logging 58 59load_dotenv() 60 61# Library does not configure logging by default. Add NullHandler so the RT logger 62# never emits "No handlers could be found". Call enable_logging() to opt in. 63logging.getLogger("RT").addHandler(logging.NullHandler()) 64 65# Do not worry about changing this version number manually. It will updated on release. 66__version__ = "1.0.0"
class Session:
    """
    The main class for managing an execution session.

    This class is responsible for setting up all the necessary components for running a Railtracks execution, including the coordinator, publisher, and state management.

    The configuration parameters follow this precedence:
    1. The parameters in the `Session` constructor.
    2. The parameters in global context variables.
    3. The default values.

    Default Values:
    - `name`: None
    - `timeout`: 150.0 seconds
    - `end_on_error`: False
    - `logging_setting`: "INFO"
    - `log_file`: None (logs will not be written to a file)
    - `broadcast_callback`: None (no callback for broadcast messages)
    - `prompt_injection`: True (the prompt will be automatically injected from context variables)
    - `save_state`: True (the state of the execution will be saved to a file at the end of the run in the `.railtracks/data/sessions/` directory)


    Args:
        name (str | None, optional): Optional name for the session. This name will be included in the saved state file if `save_state` is True.
        context (Dict[str, Any], optional): A dictionary of global context variables to be used during the execution.
        flow_name (str | None, optional): The name of the flow this session is associated with.
        flow_id (str | None, optional): The unique identifier of the flow this session is associated with.
        timeout (float, optional): The maximum number of seconds to wait for a response to your top-level request.
        end_on_error (bool, optional): If True, the execution will stop when an exception is encountered.
        logging_setting (AllowableLogLevels, optional): The setting for the level of logging you would like to have. This will override the module-level logging settings for the duration of this session.
        log_file (str | os.PathLike | None, optional): The file to which the logs will be written.
        broadcast_callback (Callable[[str], None] | Callable[[str], Coroutine[None, None, None]] | None, optional): A callback function that will be called with the broadcast messages.
        prompt_injection (bool, optional): If True, the prompt will be automatically injected from context variables.
        save_state (bool, optional): If True, the state of the execution will be saved to a file at the end of the run in the `.railtracks/data/sessions/` directory.
    """

    def __init__(
        self,
        context: Dict[str, Any] | None = None,
        *,
        flow_name: str | None = None,
        flow_id: str | None = None,
        name: str | None = None,
        timeout: float | None = None,
        end_on_error: bool | None = None,
        logging_setting: AllowableLogLevels | None = None,
        log_file: str | os.PathLike | None = None,
        broadcast_callback: (
            Callable[[str], None] | Callable[[str], Coroutine[None, None, None]] | None
        ) = None,
        prompt_injection: bool | None = None,
        save_state: bool | None = None,
        payload_callback: Callable[[dict[str, Any]], None] | None = None,
    ):
        # Resolve the effective config: explicit args win over globals, which
        # win over defaults (see global_config_precedence).

        if flow_name is None:
            warnings.warn(
                "Sessions should be tied to a flow for better observability and state management. Please use the Flow object to create and manage your sessions (see __ for more details). This warning will become an error in future versions.",
                DeprecationWarning,
            )

        self.executor_config = self.global_config_precedence(
            timeout=timeout,
            end_on_error=end_on_error,
            logging_setting=logging_setting,
            log_file=log_file,
            broadcast_callback=broadcast_callback,
            prompt_injection=prompt_injection,
            save_state=save_state,
            payload_callback=payload_callback,
        )

        if context is None:
            context = {}

        self.name = name
        self.flow_name = flow_name
        self.flow_id = flow_id

        # A session-level logging override is only active when the caller asked
        # for one; it is undone in _close() via restore_module_logging().
        self._has_custom_logging = logging_setting is not None or log_file is not None

        if self._has_custom_logging:
            mark_session_logging_override(
                session_level=self.executor_config.logging_setting,
                session_log_file=self.executor_config.log_file,
            )

        self.publisher: RTPublisher = RTPublisher()

        self._identifier = str(uuid.uuid4())

        executor_info = ExecutionInfo.create_new()
        self.coordinator = Coordinator(
            execution_modes={"async": AsyncioExecutionStrategy()}
        )
        self.rt_state = RTState(
            executor_info, self.executor_config, self.coordinator, self.publisher
        )

        # Wire everything together: coordinator listens on the publisher, the
        # broadcast subscriber is attached, and globals are registered so nested
        # calls can discover this session.
        self.coordinator.start(self.publisher)
        self._setup_subscriber()
        register_globals(
            session_id=self._identifier,
            rt_publisher=self.publisher,
            parent_id=None,
            executor_config=self.executor_config,
            global_context_vars=context,
        )

        self._start_time = time.time()

        logger.debug("Session %s is initialized" % self._identifier)

    @classmethod
    def global_config_precedence(
        cls,
        timeout: float | None,
        end_on_error: bool | None,
        logging_setting: AllowableLogLevels | None,
        log_file: str | os.PathLike | None,
        broadcast_callback: (
            Callable[[str], None] | Callable[[str], Coroutine[None, None, None]] | None
        ),
        prompt_injection: bool | None,
        save_state: bool | None,
        payload_callback: Callable[[dict[str, Any]], None] | None,
    ) -> ExecutorConfig:
        """
        Uses the following precedence order to determine the configuration parameters:
        1. The parameters in the method parameters.
        2. The parameters in global context variables.
        3. The default values.
        """
        global_executor_config = get_global_config()

        return global_executor_config.precedence_overwritten(
            timeout=timeout,
            end_on_error=end_on_error,
            logging_setting=logging_setting,
            log_file=log_file,
            subscriber=broadcast_callback,
            prompt_injection=prompt_injection,
            save_state=save_state,
            payload_callback=payload_callback,
        )

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # On exit: optionally persist the session payload to disk, deliver it
        # to the payload_callback (best effort), then release all resources.
        if self.executor_config.save_state:
            try:
                railtracks_dir = Path(".railtracks")
                sessions_dir = railtracks_dir / "data" / "sessions"
                sessions_dir.mkdir(
                    parents=True, exist_ok=True
                )  # Creates directory structure if doesn't exist, skips otherwise.

                # Try to create file path with name, fallback to identifier only if there's an issue
                if self.flow_name is not None:
                    name = self.flow_name
                elif self.name is not None:
                    name = self.name
                else:
                    name = ""

                try:
                    file_path = sessions_dir / f"{name}_{self._identifier}.json"
                    file_path.touch()
                except FileNotFoundError:
                    logger.warning(
                        get_message(
                            ExceptionMessageKey.INVALID_SESSION_FILE_NAME_WARN
                        ).format(name=name, identifier=self._identifier)
                    )
                    file_path = sessions_dir / f"{self._identifier}.json"

                logger.info("Saving execution info to %s" % file_path)

                file_path.write_text(json.dumps(self.payload()))

            except Exception as e:
                # FIX: corrected garbled log message ("saving to execution info").
                logger.error(
                    "Error while saving execution info to file",
                    exc_info=e,
                )
        try:
            if self.executor_config.payload_callback is not None:
                self.executor_config.payload_callback(self.payload())
        except Exception:
            # FIX: failures were silently swallowed (leftover TODO). Delivery is
            # still best-effort, but the failure is now visible in the logs.
            logger.warning(
                "payload_callback raised during Session exit; ignoring.",
                exc_info=True,
            )

        self._close()

    def _setup_subscriber(self):
        """
        Prepares and attaches the saved broadcast_callback to the publisher attached to this runner.
        """

        if self.executor_config.subscriber is not None:
            self.publisher.subscribe(
                stream_subscriber(self.executor_config.subscriber),
                name="Streaming Subscriber",
            )

    def _close(self):
        """
        Closes the runner and cleans up all resources.

        - Shuts down the state object
        - Detaches logging handlers so they aren't duplicated
        - Deletes all the global variables that were registered in the context
        """
        # FIX: Resource leak - publisher background task wasn't being shut down on Session exit
        # VISION: Session owns publisher lifecycle and must clean up all resources when exiting
        if self.publisher.is_running():
            try:
                # Signal shutdown by setting the flag - the loop will check this and exit
                self.publisher._running = False

                # Try to cancel the background task if it exists and isn't done
                if (
                    self.publisher.pub_loop is not None
                    and not self.publisher.pub_loop.done()
                ):
                    try:
                        # Cancel the task - it will check _running and exit naturally
                        self.publisher.pub_loop.cancel()
                    except Exception:
                        # Task might be done or in a different loop, that's okay
                        pass
            except Exception:
                # If shutdown fails for any reason, log it but don't crash
                logger.warning(
                    "Failed to shutdown publisher during Session cleanup. "
                    "This may indicate a resource leak.",
                    exc_info=True,
                )

        self.rt_state.shutdown()

        if self._has_custom_logging:
            restore_module_logging()

        delete_globals()
        # by deleting all of the state variables we are ensuring that the next time we create a runner it is fresh

    @property
    def info(self) -> ExecutionInfo:
        """
        Returns the current state of the runner.

        This is useful for debugging and viewing the current state of the run.
        """
        return self.rt_state.info

    def payload(self) -> Dict[str, Any]:
        """
        Gets the complete json payload tied to this session.

        The outputted json schema is maintained in (link here)
        """
        info = self.info

        run_list = info.graph_serialization()

        full_dict = {
            "flow_name": self.flow_name,
            "flow_id": self.flow_id,
            "session_id": self._identifier,
            "session_name": self.name,
            "start_time": self._start_time,
            "end_time": time.time(),
            "runs": run_list,
        }

        return json.loads(json.dumps(full_dict))
The main class for managing an execution session.
This class is responsible for setting up all the necessary components for running a Railtracks execution, including the coordinator, publisher, and state management.
For the configuration parameters of the setting. It will follow this precedence:
- The parameters in the `Session` constructor.
- The parameters in global context variables.
- The default values.
Default Values:
`name`: None; `timeout`: 150.0 seconds; `end_on_error`: False; `logging_setting`: "INFO"; `log_file`: None (logs will not be written to a file); `broadcast_callback`: None (no callback for broadcast messages); `prompt_injection`: True (the prompt will be automatically injected from context variables); `save_state`: True (the state of the execution will be saved to a file at the end of the run in the `.railtracks/data/sessions/` directory)
Arguments:
- name (str | None, optional): Optional name for the session. This name will be included in the saved state file if `save_state` is True.
- context (Dict[str, Any], optional): A dictionary of global context variables to be used during the execution.
- flow_name (str | None, optional): The name of the flow this session is associated with.
- flow_id (str | None, optional): The unique identifier of the flow this session is associated with.
- timeout (float, optional): The maximum number of seconds to wait for a response to your top-level request.
- end_on_error (bool, optional): If True, the execution will stop when an exception is encountered.
- logging_setting (AllowableLogLevels, optional): The setting for the level of logging you would like to have. This will override the module-level logging settings for the duration of this session.
- log_file (str | os.PathLike | None, optional): The file to which the logs will be written.
- broadcast_callback (Callable[[str], None] | Callable[[str], Coroutine[None, None, None]] | None, optional): A callback function that will be called with the broadcast messages.
- prompt_injection (bool, optional): If True, the prompt will be automatically injected from context variables.
- save_state (bool, optional): If True, the state of the execution will be saved to a file at the end of the run in the `.railtracks/data/sessions/` directory.
    def __init__(
        self,
        context: Dict[str, Any] | None = None,
        *,
        flow_name: str | None = None,
        flow_id: str | None = None,
        name: str | None = None,
        timeout: float | None = None,
        end_on_error: bool | None = None,
        logging_setting: AllowableLogLevels | None = None,
        log_file: str | os.PathLike | None = None,
        broadcast_callback: (
            Callable[[str], None] | Callable[[str], Coroutine[None, None, None]] | None
        ) = None,
        prompt_injection: bool | None = None,
        save_state: bool | None = None,
        payload_callback: Callable[[dict[str, Any]], None] | None = None,
    ):
        # Resolve the effective config: explicit arguments take precedence over
        # global context values, which take precedence over the defaults.

        # A session without a flow is deprecated (but still works for now).
        if flow_name is None:
            warnings.warn(
                "Sessions should be tied to a flow for better observability and state management. Please use the Flow object to create and manage your sessions (see __ for more details). This warning will become an error in future versions.",
                DeprecationWarning,
            )

        self.executor_config = self.global_config_precedence(
            timeout=timeout,
            end_on_error=end_on_error,
            logging_setting=logging_setting,
            log_file=log_file,
            broadcast_callback=broadcast_callback,
            prompt_injection=prompt_injection,
            save_state=save_state,
            payload_callback=payload_callback,
        )

        if context is None:
            context = {}

        self.name = name
        self.flow_name = flow_name
        self.flow_id = flow_id

        # Only apply a session-scoped logging override if the caller explicitly
        # provided a level or log file; restored when the session closes.
        self._has_custom_logging = logging_setting is not None or log_file is not None

        if self._has_custom_logging:
            mark_session_logging_override(
                session_level=self.executor_config.logging_setting,
                session_log_file=self.executor_config.log_file,
            )

        self.publisher: RTPublisher = RTPublisher()

        # Unique identifier for this session; also used in saved-state filenames.
        self._identifier = str(uuid.uuid4())

        executor_info = ExecutionInfo.create_new()
        self.coordinator = Coordinator(
            execution_modes={"async": AsyncioExecutionStrategy()}
        )
        self.rt_state = RTState(
            executor_info, self.executor_config, self.coordinator, self.publisher
        )

        # Order matters here: start the coordinator on the publisher, attach the
        # broadcast subscriber, then expose this session via global context so
        # nested calls can find it.
        self.coordinator.start(self.publisher)
        self._setup_subscriber()
        register_globals(
            session_id=self._identifier,
            rt_publisher=self.publisher,
            parent_id=None,
            executor_config=self.executor_config,
            global_context_vars=context,
        )

        self._start_time = time.time()

        logger.debug("Session %s is initialized" % self._identifier)
157 @classmethod 158 def global_config_precedence( 159 cls, 160 timeout: float | None, 161 end_on_error: bool | None, 162 logging_setting: AllowableLogLevels | None, 163 log_file: str | os.PathLike | None, 164 broadcast_callback: ( 165 Callable[[str], None] | Callable[[str], Coroutine[None, None, None]] | None 166 ), 167 prompt_injection: bool | None, 168 save_state: bool | None, 169 payload_callback: Callable[[dict[str, Any]], None] | None, 170 ) -> ExecutorConfig: 171 """ 172 Uses the following precedence order to determine the configuration parameters: 173 1. The parameters in the method parameters. 174 2. The parameters in global context variables. 175 3. The default values. 176 """ 177 global_executor_config = get_global_config() 178 179 return global_executor_config.precedence_overwritten( 180 timeout=timeout, 181 end_on_error=end_on_error, 182 logging_setting=logging_setting, 183 log_file=log_file, 184 subscriber=broadcast_callback, 185 prompt_injection=prompt_injection, 186 save_state=save_state, 187 payload_callback=payload_callback, 188 )
Uses the following precedence order to determine the configuration parameters:
- The parameters in the method parameters.
- The parameters in global context variables.
- The default values.
    @property
    def info(self) -> ExecutionInfo:
        """
        Returns the current state of the runner.

        This is useful for debugging and viewing the current state of the run.
        """
        # Delegates to the underlying RTState object, which owns execution state.
        return self.rt_state.info
Returns the current state of the runner.
This is useful for debugging and viewing the current state of the run.
301 def payload(self) -> Dict[str, Any]: 302 """ 303 Gets the complete json payload tied to this session. 304 305 The outputted json schema is maintained in (link here) 306 """ 307 info = self.info 308 309 run_list = info.graph_serialization() 310 311 full_dict = { 312 "flow_name": self.flow_name, 313 "flow_id": self.flow_id, 314 "session_id": self._identifier, 315 "session_name": self.name, 316 "start_time": self._start_time, 317 "end_time": time.time(), 318 "runs": run_list, 319 } 320 321 return json.loads(json.dumps(full_dict))
Gets the complete json payload tied to this session.
The outputted json schema is maintained in (link here)
def session(
    func: Callable[_P, Coroutine[Any, Any, _TOutput]] | None = None,
    *,
    name: str | None = None,
    context: Dict[str, Any] | None = None,
    timeout: float | None = None,
    end_on_error: bool | None = None,
    logging_setting: AllowableLogLevels | None = None,
    log_file: str | os.PathLike | None = None,
    broadcast_callback: (
        Callable[[str], None] | Callable[[str], Coroutine[None, None, None]] | None
    ) = None,
    prompt_injection: bool | None = None,
    save_state: bool | None = None,
    payload_callback: Callable[[dict[str, Any]], None] | None = None,
) -> (
    Callable[_P, Coroutine[Any, Any, Tuple[_TOutput, Session]]]
    | Callable[
        [Callable[_P, Coroutine[Any, Any, _TOutput]]],
        Callable[_P, Coroutine[Any, Any, Tuple[_TOutput, Session]]],
    ]
):
    """
    This decorator automatically creates and manages a Session context for the decorated function,
    allowing async functions to use Railtracks operations without manually managing the session lifecycle.

    Can be used as:
    - @session (without parentheses) - uses default settings
    - @session() (with empty parentheses) - uses default settings
    - @session(name="my_task", timeout=30) (with configuration parameters)

    When using this decorator, the function returns a tuple containing:
    1. The original function's return value
    2. The Session object used during execution

    This allows access to session information (like execution state, logs, etc.) after the function completes,
    while maintaining the simplicity of decorator usage.

    Args:
        name (str | None, optional): Optional name for the session. This name will be included in the saved state file if `save_state` is True.
        context (Dict[str, Any], optional): A dictionary of global context variables to be used during the execution.
        timeout (float, optional): The maximum number of seconds to wait for a response to your top-level request.
        end_on_error (bool, optional): If True, the execution will stop when an exception is encountered.
        logging_setting (AllowableLogLevels, optional): The setting for the level of logging you would like to have. This will override the module-level logging settings for the duration of this session.
        log_file (str | os.PathLike | None, optional): The file to which the logs will be written.
        broadcast_callback (Callable[[str], None] | Callable[[str], Coroutine[None, None, None]] | None, optional): A callback function that will be called with the broadcast messages.
        prompt_injection (bool, optional): If True, the prompt will be automatically injected from context variables.
        save_state (bool, optional): If True, the state of the execution will be saved to a file at the end of the run in the `.railtracks/data/sessions/` directory.
        payload_callback (Callable[[dict[str, Any]], None] | None, optional): A callback invoked with the session payload when the session exits (mirrors the `Session` constructor parameter).

    Returns:
        When used as @session (without parentheses): Returns the decorated function that returns (result, session).
        When used as @session(...) (with parameters): Returns a decorator function that takes an async function
        and returns a new async function that returns (result, session).
    """

    def decorator(
        target_func: Callable[_P, Coroutine[Any, Any, _TOutput]],
    ) -> Callable[_P, Coroutine[Any, Any, Tuple[_TOutput, Session]]]:
        # Validate that the decorated function is async
        if not inspect.iscoroutinefunction(target_func):
            raise TypeError(
                f"@session decorator can only be applied to async functions. "
                f"Function '{target_func.__name__}' is not async. "
                f"Add 'async' keyword to your function definition."
            )

        @wraps(target_func)
        async def wrapper(
            *args: _P.args, **kwargs: _P.kwargs
        ) -> Tuple[_TOutput, Session]:
            # FIX: payload_callback is now forwarded; previously the decorator
            # could not configure it even though Session supports it.
            session_obj = Session(
                context=context,
                timeout=timeout,
                end_on_error=end_on_error,
                logging_setting=logging_setting,
                log_file=log_file,
                broadcast_callback=broadcast_callback,
                name=name,
                prompt_injection=prompt_injection,
                save_state=save_state,
                payload_callback=payload_callback,
            )

            with session_obj:
                result = await target_func(*args, **kwargs)
                return result, session_obj

        return wrapper

    # If used as @session without parentheses
    if func is not None:
        return decorator(func)

    # If used as @session(...)
    return decorator
This decorator automatically creates and manages a Session context for the decorated function, allowing async functions to use Railtracks operations without manually managing the session lifecycle.
Can be used as:
- @session (without parentheses) - uses default settings
- @session() (with empty parentheses) - uses default settings
- @session(name="my_task", timeout=30) (with configuration parameters)
When using this decorator, the function returns a tuple containing:
- The original function's return value
- The Session object used during execution
This allows access to session information (like execution state, logs, etc.) after the function completes, while maintaining the simplicity of decorator usage.
Arguments:
- name (str | None, optional): Optional name for the session. This name will be included in the saved state file if `save_state` is True.
- context (Dict[str, Any], optional): A dictionary of global context variables to be used during the execution.
- timeout (float, optional): The maximum number of seconds to wait for a response to your top-level request.
- end_on_error (bool, optional): If True, the execution will stop when an exception is encountered.
- logging_setting (AllowableLogLevels, optional): The setting for the level of logging you would like to have. This will override the module-level logging settings for the duration of this session.
- log_file (str | os.PathLike | None, optional): The file to which the logs will be written.
- broadcast_callback (Callable[[str], None] | Callable[[str], Coroutine[None, None, None]] | None, optional): A callback function that will be called with the broadcast messages.
- prompt_injection (bool, optional): If True, the prompt will be automatically injected from context variables.
- save_state (bool, optional): If True, the state of the execution will be saved to a file at the end of the run in the `.railtracks/data/sessions/` directory.
Returns:
When used as @session (without parentheses): Returns the decorated function that returns (result, session). When used as @session(...) (with parameters): Returns a decorator function that takes an async function and returns a new async function that returns (result, session).
async def call(
    node_: Callable[_P, Node[_TOutput]] | RTFunction[_P, _TOutput],
    *args: _P.args,
    **kwargs: _P.kwargs,
) -> _TOutput:
    """
    Call a node from within a node inside the framework. Returns a coroutine you
    can combine freely with async/await logic.

    Usage:
    ```python
    # for sequential operation
    result = await call(NodeA, "hello world", 42)

    # for parallel operation
    tasks = [call(NodeA, "hello world", i) for i in range(10)]
    results = await asyncio.gather(*tasks)
    ```

    Args:
        node: The node type you would like to create. This could be a function decorated with `@function_node`, a function, or a Node instance.
        *args: The arguments to pass to the node
        **kwargs: The keyword arguments to pass to the node
    """
    # Normalize: a plain function must first be wrapped into a node type.
    # (The overloads we accept make strict typing here impractical.)
    if isinstance(node_, FunctionType):
        node = extract_node_from_function(node_)
    else:
        node = node_

    if not is_context_present():
        # No session context exists: spin up a temporary Session for this call.
        # Lazy import is required here to break a circular import.
        from railtracks import Session

        with Session():
            return await _start(node, args=args, kwargs=kwargs)

    if not is_context_active():
        # A context exists but is not yet active, so this is the top-level request.
        return await _start(node, args=args, kwargs=kwargs)

    # The context is active: simply run the node as a child.
    return await _run(node, args=args, kwargs=kwargs)
Call a node from within a node inside the framework. This will return a coroutine that you can interact with in whatever way using async/await logic.
Usage:
# for sequential operation
result = await call(NodeA, "hello world", 42)
# for parallel operation
tasks = [call(NodeA, "hello world", i) for i in range(10)]
results = await asyncio.gather(*tasks)
Arguments:
- node: The node type you would like to create. This could be a function decorated with `@function_node`, a plain function, or a Node instance.
- *args: The arguments to pass to the node
- **kwargs: The keyword arguments to pass to the node
async def broadcast(item: str):
    """
    Stream the given message to your configured listener.

    Publishing the item triggers the broadcast_callback callback you have already provided.

    Args:
        item (str): The item you want to stream.
    """
    publisher = get_publisher()
    message = Streaming(node_id=get_parent_id(), streamed_object=item)
    await publisher.publish(message)
Streams the given message
This will trigger the broadcast_callback callback you have already provided.
Arguments:
- item (str): The item you want to stream.
async def call_batch(
    node: Callable[..., Node[_TOutput]]
    | Callable[..., _TOutput]
    | _AsyncNodeAttachedFunc[_P, _TOutput]
    | _SyncNodeAttachedFunc[_P, _TOutput],
    *iterables: Iterable[Any],
    return_exceptions: bool = True,
):
    """
    Complete a node over multiple iterables, allowing for parallel execution.

    Note the results will be returned in the order of the iterables, not the order of completion.

    If one of the nodes returns an exception, the thrown exception will be included as a response.

    Args:
        node: The node type to create.
        *iterables: The iterables to map the node over.
        return_exceptions: If True, exceptions will be returned as part of the results.
            If False, exceptions will be raised immediately, and you will lose access to the results.
            Defaults to True.

    Returns:
        An iterable of results from the node.

    Usage:
    ```python
    results = await call_batch(NodeA, ["hello world"] * 10)
    for result in results:
        handle(result)
    ```
    """
    # the typing here is messy, but there is no clean way around it given the overloads.
    # zip(*iterables) pairs up the i-th element of each iterable into one call's arguments.
    contracts = [call(node, *args) for args in zip(*iterables)]

    results = await asyncio.gather(*contracts, return_exceptions=return_exceptions)
    return results
Complete a node over multiple iterables, allowing for parallel execution.
Note the results will be returned in the order of the iterables, not the order of completion.
If one of the nodes returns an exception, the thrown exception will be included as a response.
Arguments:
- node: The node type to create.
- *iterables: The iterables to map the node over.
- return_exceptions: If True, exceptions will be returned as part of the results. If False, exceptions will be raised immediately, and you will lose access to the results. Defaults to true.
Returns:
An iterable of results from the node.
Usage:
results = await call_batch(NodeA, ["hello world"] * 10)
for result in results:
    handle(result)
class ExecutionInfo:
    """
    A class that contains the full details of the state of a run at any given point in time.

    The class is designed to be used as a snapshot of state that can be used to display the state of the run, or to
    create a graphical representation of the system.
    """

    def __init__(
        self,
        request_forest: RequestForest,
        node_forest: NodeForest,
        stamper: StampManager,
    ):
        """
        Args:
            request_forest (RequestForest): The collection of requests made during the run.
            node_forest (NodeForest): The collection of nodes created during the run.
            stamper (StampManager): The manager holding the time stamps collected during the run.
        """
        self.request_forest = request_forest
        self.node_forest = node_forest
        self.stamper = stamper

    @classmethod
    def default(cls) -> ExecutionInfo:
        """Creates a new "empty" instance of the ExecutionInfo class with the default values. Alias for `create_new`."""
        return cls.create_new()

    @classmethod
    def create_new(
        cls,
    ) -> ExecutionInfo:
        """
        Creates a new empty instance of the state variables (empty request/node forests and a fresh stamp manager).

        """
        request_heap = RequestForest()
        node_heap = NodeForest()
        stamper = StampManager()

        return ExecutionInfo(
            request_forest=request_heap,
            node_forest=node_heap,
            stamper=stamper,
        )

    @property
    def answer(self):
        """Convenience property to access the answer of the run (delegates to the request forest)."""
        return self.request_forest.answer

    @property
    def all_stamps(self) -> List[Stamp]:
        """Convenience property to access all the stamps of the run."""
        return self.stamper.all_stamps

    @property
    def name(self):
        """
        Gets the name of the graph by pulling the name of the insertion request.

        Returns None if the insertion request is not present or there are multiple insertion requests.
        """
        insertion_requests = self.insertion_requests

        # The name is only defined for the length of 1.
        # NOTE: Maybe we should send a warning once to user in other cases.
        if len(insertion_requests) != 1:
            return None

        i_r = insertion_requests[0]

        return self.node_forest.get_node_type(i_r.sink_id).name()

    @property
    def insertion_requests(self):
        """A convenience property to access all the insertion requests of the run."""
        return self.request_forest.insertion_request

    def _get_info(self, ids: List[str] | str | None = None) -> ExecutionInfo:
        """
        Gets a subset of the current state based on the provided node ids. It will contain all the children of the provided node ids

        Note: If no ids are provided, the full state is returned.

        Args:
            ids (List[str] | str | None): A list of node ids to filter the state by. If None, the full state is returned.

        Returns:
            ExecutionInfo: A new instance of ExecutionInfo containing only the children of the provided ids.

        Raises:
            ValueError: If one of the provided ids is not present in the request forest.
        """
        if ids is None:
            return self
        else:
            # normalize a single id into a list of ids
            if isinstance(ids, str):
                ids = [ids]

            # we need to quickly check to make sure these ids are valid
            for identifier in ids:
                if identifier not in self.request_forest:
                    raise ValueError(
                        f"Identifier '{identifier}' not found in the current state."
                    )

            new_node_forest, new_request_forest = create_sub_state_info(
                self.node_forest.heap(),
                self.request_forest.heap(),
                ids,
            )
            # the stamper is shared with the parent snapshot rather than filtered
            return ExecutionInfo(
                node_forest=new_node_forest,
                request_forest=new_request_forest,
                stamper=self.stamper,
            )

    def _to_graph(self) -> Tuple[List[Vertex], List[Edge]]:
        """
        Converts the current state into its graph representation.

        Returns:
            List[Vertex]: An iterable of vertices in the graph.
            List[Edge]: An iterable of edges in the graph.
        """
        return self.node_forest.to_vertices(), self.request_forest.to_edges()

    def graph_serialization(self) -> list[dict[str, Any]]:
        """
        Creates a JSON-safe representation of this info object designed to be used to construct a graph for this
        info object.

        The result is a list with one entry per run (i.e. per insertion request). Each entry contains:
        - `name`: the name of the run (see the `name` property).
        - `run_id`: the identifier of the run's insertion request.
        - `nodes`: all the nodes in the run's graph, represented as `Vertex` objects.
        - `edges`: all the edges in the run's graph, represented as `Edge` objects.
        - `steps`: the stamps associated with the run.
        - `status`: the status of the run's insertion request.
        - `start_time` / `end_time`: timestamps bounding the run (`end_time` is None while still running).
        """
        parent_nodes = [x.identifier for x in self.insertion_requests]

        infos = [self._get_info(parent_node) for parent_node in parent_nodes]

        runs = []

        for info, parent_node_id in zip(infos, parent_nodes):
            insertion_requests = info.request_forest.insertion_request

            # each sub-state was filtered down to exactly one insertion request above
            assert len(insertion_requests) == 1
            parent_request = insertion_requests[0]

            all_parents = parent_request.get_all_parents()

            # the oldest ancestor's stamp marks the start of the run
            start_time = all_parents[-1].stamp.time

            # a run should complete at most once
            assert len([x for x in all_parents if x.status == "Completed"]) <= 1
            end_time = None
            for req in all_parents:
                if req.status in ["Completed", "Failed"]:
                    end_time = req.stamp.time
                    break

            entry = {
                "name": info.name,
                "run_id": parent_node_id,
                "nodes": info.node_forest.to_vertices(),
                "status": parent_request.status,
                "edges": info.request_forest.to_edges(),
                "steps": _get_stamps_from_forests(
                    info.node_forest, info.request_forest
                ),
                "start_time": start_time,
                "end_time": end_time,
            }
            runs.append(entry)

        # round-trip through the custom encoder to produce a plain, JSON-safe structure
        return json.loads(
            json.dumps(
                runs,
                cls=RTJSONEncoder,
            )
        )
A class that contains the full details of the state of a run at any given point in time.
The class is designed to be used as a snapshot of state that can be used to display the state of the run, or to create a graphical representation of the system.
    @classmethod
    def default(cls) -> ExecutionInfo:
        """Creates a new "empty" instance of the ExecutionInfo class with the default values. Alias for `create_new`."""
        return cls.create_new()
Creates a new "empty" instance of the ExecutionInfo class with the default values.
42 @classmethod 43 def create_new( 44 cls, 45 ) -> ExecutionInfo: 46 """ 47 Creates a new empty instance of state variables with the provided executor configuration. 48 49 """ 50 request_heap = RequestForest() 51 node_heap = NodeForest() 52 stamper = StampManager() 53 54 return ExecutionInfo( 55 request_forest=request_heap, 56 node_forest=node_heap, 57 stamper=stamper, 58 )
Creates a new empty instance of state variables with the provided executor configuration.
    @property
    def answer(self):
        """Convenience property to access the answer of the run (delegates to the request forest)."""
        return self.request_forest.answer
Convenience method to access the answer of the run.
    @property
    def all_stamps(self) -> List[Stamp]:
        """Convenience property to access all the stamps of the run (delegates to the stamp manager)."""
        return self.stamper.all_stamps
Convenience method to access all the stamps of the run.
    @property
    def name(self):
        """
        Gets the name of the graph by pulling the name of the insertion request.

        Returns None if the insertion request is not present or there are multiple insertion requests.
        """
        insertion_requests = self.insertion_requests

        # The name is only defined for the length of 1.
        # NOTE: Maybe we should send a warning once to user in other cases.
        if len(insertion_requests) != 1:
            return None

        i_r = insertion_requests[0]

        return self.node_forest.get_node_type(i_r.sink_id).name()
Gets the name of the graph by pulling the name of the insertion request. It will return None if the insertion request is not present or there are multiple insertion requests.
    @property
    def insertion_requests(self):
        """A convenience property to access all the insertion requests of the run (delegates to the request forest)."""
        return self.request_forest.insertion_request
A convenience method to access all the insertion requests of the run.
    def graph_serialization(self) -> list[dict[str, Any]]:
        """
        Creates a JSON-safe representation of this info object designed to be used to construct a graph for this
        info object.

        The result is a list with one entry per run (i.e. per insertion request). Each entry contains:
        - `name`: the name of the run (see the `name` property).
        - `run_id`: the identifier of the run's insertion request.
        - `nodes`: all the nodes in the run's graph, represented as `Vertex` objects.
        - `edges`: all the edges in the run's graph, represented as `Edge` objects.
        - `steps`: the stamps associated with the run.
        - `status`: the status of the run's insertion request.
        - `start_time` / `end_time`: timestamps bounding the run (`end_time` is None while still running).
        """
        parent_nodes = [x.identifier for x in self.insertion_requests]

        infos = [self._get_info(parent_node) for parent_node in parent_nodes]

        runs = []

        for info, parent_node_id in zip(infos, parent_nodes):
            insertion_requests = info.request_forest.insertion_request

            # each sub-state was filtered down to exactly one insertion request above
            assert len(insertion_requests) == 1
            parent_request = insertion_requests[0]

            all_parents = parent_request.get_all_parents()

            # the oldest ancestor's stamp marks the start of the run
            start_time = all_parents[-1].stamp.time

            # a run should complete at most once
            assert len([x for x in all_parents if x.status == "Completed"]) <= 1
            end_time = None
            for req in all_parents:
                if req.status in ["Completed", "Failed"]:
                    end_time = req.stamp.time
                    break

            entry = {
                "name": info.name,
                "run_id": parent_node_id,
                "nodes": info.node_forest.to_vertices(),
                "status": parent_request.status,
                "edges": info.request_forest.to_edges(),
                "steps": _get_stamps_from_forests(
                    info.node_forest, info.request_forest
                ),
                "start_time": start_time,
                "end_time": end_time,
            }
            runs.append(entry)

        # round-trip through the custom encoder to produce a plain, JSON-safe structure
        return json.loads(
            json.dumps(
                runs,
                cls=RTJSONEncoder,
            )
        )
Creates a string (JSON) representation of this info object designed to be used to construct a graph for this info object.
Some important notes about its structure are outlined below:
- The `nodes` key contains a list of all the nodes in the graph, represented as `Vertex` objects.
- The `edges` key contains a list of all the edges in the graph, represented as `Edge` objects.
- The `stamps` key contains an ease of use list of all the stamps associated with the run, represented as `Stamp` objects.
- The "nodes" and "requests" key will be outlined with normal graph details like connections and identifiers in addition to a loose details object.
- However, both will carry an addition param called "stamp" which is a timestamp style object.
- They also will carry a "parent" param which is a recursive structure that allows you to traverse the graph in time.
(end of structure notes)
class ExecutorConfig:
    def __init__(
        self,
        *,
        timeout: float = 150.0,
        end_on_error: bool = False,
        logging_setting: AllowableLogLevels = "INFO",
        log_file: str | os.PathLike | None = None,
        broadcast_callback: (
            Callable[[str], None] | Callable[[str], Coroutine[None, None, None]] | None
        ) = None,
        prompt_injection: bool = True,
        save_state: bool = True,
        payload_callback: Callable[[dict[str, Any]], None] | None = None,
    ):
        """
        ExecutorConfig is a special configuration object designed to allow customization of the executor in the RT system.

        Args:
            timeout (float): The maximum number of seconds to wait for a response to your top level request
            end_on_error (bool): If true, the executor will stop execution when an exception is encountered.
            logging_setting (AllowableLogLevels): The setting for the level of logging you would like to have.
            log_file (str | os.PathLike | None): The file to which the logs will be written. If None, no file will be created.
            broadcast_callback (Callable or Coroutine): A function or coroutine that will handle streaming messages.
            prompt_injection (bool): If true, prompts can be injected with global context
            save_state (bool): If true, the state of the executor will be saved to disk.
            payload_callback (Callable | None): A callback invoked with payload dictionaries, or None to disable.
        """
        self.timeout = timeout
        self.end_on_error = end_on_error
        self.logging_setting = logging_setting
        # NOTE: the broadcast_callback argument is stored under the `subscriber` attribute.
        self.subscriber = broadcast_callback
        self.log_file = log_file
        self.prompt_injection = prompt_injection
        self.save_state = save_state
        self.payload_callback = payload_callback

    @property
    def logging_setting(self) -> AllowableLogLevels:
        """The configured logging level."""
        return self._logging_setting

    @logging_setting.setter
    def logging_setting(self, value: AllowableLogLevels):
        # validate against the known level names before storing
        if value not in str_to_log_level:
            raise ValueError(
                f"logging_setting must be one of {str_to_log_level}, got {value}"
            )
        self._logging_setting: AllowableLogLevels = value

    def precedence_overwritten(
        self,
        *,
        timeout: float | None = None,
        end_on_error: bool | None = None,
        logging_setting: AllowableLogLevels | None = None,
        log_file: str | os.PathLike | None = None,
        subscriber: (
            Callable[[str], None] | Callable[[str], Coroutine[None, None, None]] | None
        ) = None,
        prompt_injection: bool | None = None,
        save_state: bool | None = None,
        payload_callback: Callable[[dict[str, Any]], None] | None = None,
    ):
        """
        Creates a new ExecutorConfig in which every parameter provided here (not None) overrides
        the corresponding value of this instance; unspecified values are carried over unchanged.

        The current instance is not modified; the new instance is returned.
        """
        return ExecutorConfig(
            timeout=timeout if timeout is not None else self.timeout,
            end_on_error=end_on_error
            if end_on_error is not None
            else self.end_on_error,
            logging_setting=logging_setting
            if logging_setting is not None
            else self.logging_setting,
            log_file=log_file if log_file is not None else self.log_file,
            broadcast_callback=subscriber
            if subscriber is not None
            else self.subscriber,
            prompt_injection=prompt_injection
            if prompt_injection is not None
            else self.prompt_injection,
            save_state=save_state if save_state is not None else self.save_state,
            payload_callback=payload_callback
            if payload_callback is not None
            else self.payload_callback,
        )

    def __repr__(self):
        return (
            f"ExecutorConfig(timeout={self.timeout}, end_on_error={self.end_on_error}, "
            f"logging_setting={self.logging_setting}, log_file={self.log_file}, "
            f"prompt_injection={self.prompt_injection}, "
            f"save_state={self.save_state}, payload_callback={self.payload_callback})"
        )
    def __init__(
        self,
        *,
        timeout: float = 150.0,
        end_on_error: bool = False,
        logging_setting: AllowableLogLevels = "INFO",
        log_file: str | os.PathLike | None = None,
        broadcast_callback: (
            Callable[[str], None] | Callable[[str], Coroutine[None, None, None]] | None
        ) = None,
        prompt_injection: bool = True,
        save_state: bool = True,
        payload_callback: Callable[[dict[str, Any]], None] | None = None,
    ):
        """
        ExecutorConfig is a special configuration object designed to allow customization of the executor in the RT system.

        Args:
            timeout (float): The maximum number of seconds to wait for a response to your top level request
            end_on_error (bool): If true, the executor will stop execution when an exception is encountered.
            logging_setting (AllowableLogLevels): The setting for the level of logging you would like to have.
            log_file (str | os.PathLike | None): The file to which the logs will be written. If None, no file will be created.
            broadcast_callback (Callable or Coroutine): A function or coroutine that will handle streaming messages.
            prompt_injection (bool): If true, prompts can be injected with global context
            save_state (bool): If true, the state of the executor will be saved to disk.
            payload_callback (Callable | None): A callback invoked with payload dictionaries, or None to disable.
        """
        self.timeout = timeout
        self.end_on_error = end_on_error
        self.logging_setting = logging_setting
        # NOTE: the broadcast_callback argument is stored under the `subscriber` attribute.
        self.subscriber = broadcast_callback
        self.log_file = log_file
        self.prompt_injection = prompt_injection
        self.save_state = save_state
        self.payload_callback = payload_callback
ExecutorConfig is a special configuration object designed to allow customization of the executor in the RT system.
Arguments:
- timeout (float): The maximum number of seconds to wait for a response to your top level request
- end_on_error (bool): If true, the executor will stop execution when an exception is encountered.
- logging_setting (AllowableLogLevels): The setting for the level of logging you would like to have.
- log_file (str | os.PathLike | None): The file to which the logs will be written. If None, no file will be created.
- broadcast_callback (Callable or Coroutine): A function or coroutine that will handle streaming messages.
- prompt_injection (bool): If true, prompts can be injected with global context
- save_state (bool): If true, the state of the executor will be saved to disk.
    def precedence_overwritten(
        self,
        *,
        timeout: float | None = None,
        end_on_error: bool | None = None,
        logging_setting: AllowableLogLevels | None = None,
        log_file: str | os.PathLike | None = None,
        subscriber: (
            Callable[[str], None] | Callable[[str], Coroutine[None, None, None]] | None
        ) = None,
        prompt_injection: bool | None = None,
        save_state: bool | None = None,
        payload_callback: Callable[[dict[str, Any]], None] | None = None,
    ):
        """
        Creates a new ExecutorConfig in which every parameter provided here (not None) overrides
        the corresponding value of this instance; unspecified values are carried over unchanged.

        The current instance is not modified; the new instance is returned.
        """
        return ExecutorConfig(
            timeout=timeout if timeout is not None else self.timeout,
            end_on_error=end_on_error
            if end_on_error is not None
            else self.end_on_error,
            logging_setting=logging_setting
            if logging_setting is not None
            else self.logging_setting,
            log_file=log_file if log_file is not None else self.log_file,
            broadcast_callback=subscriber
            if subscriber is not None
            else self.subscriber,
            prompt_injection=prompt_injection
            if prompt_injection is not None
            else self.prompt_injection,
            save_state=save_state if save_state is not None else self.save_state,
            payload_callback=payload_callback
            if payload_callback is not None
            else self.payload_callback,
        )
If any of the parameters are provided (not None), this creates a new ExecutorConfig in which the provided values override the current ones and returns it; unspecified values are carried over, and the current instance is not modified.
def set_config(
    *,
    timeout: float | None = None,
    end_on_error: bool | None = None,
    logging_setting: AllowableLogLevels | None = None,
    log_file: str | os.PathLike | None = None,
    broadcast_callback: (
        Callable[[str], None] | Callable[[str], Coroutine[None, None, None]] | None
    ) = None,
    prompt_injection: bool | None = None,
    save_state: bool | None = None,
    payload_callback: Callable[[dict], None] | None = None,
):
    """
    Sets the global configuration for the executor. This will be propagated to all new runners created after this call.

    - If you call this function after the runner has been created, it will not affect the current runner.
    - This function will only overwrite the values that are provided, leaving the rest unchanged.

    Args:
        timeout (float | None): The maximum number of seconds to wait for a response to your top level request.
        end_on_error (bool | None): If true, the executor will stop execution when an exception is encountered.
        logging_setting (AllowableLogLevels | None): The setting for the level of logging you would like to have.
        log_file (str | os.PathLike | None): The file to which the logs will be written. If None, no file will be created.
        broadcast_callback (Callable or Coroutine or None): A function or coroutine that will handle streaming messages.
        prompt_injection (bool | None): If true, prompts can be injected with global context.
        save_state (bool | None): If true, the state of the executor will be saved to disk.
        payload_callback (Callable | None): A callback invoked with payload dictionaries.
    """

    if is_context_active():
        warnings.warn(
            "The executor config is being set after the runner has been created, this is not recommended"
        )

    config = global_executor_config.get()

    if logging_setting or log_file:
        # default will be set at module import time, this is for overwrites
        configure_module_logging(level=logging_setting, log_file=log_file)

    new_config = config.precedence_overwritten(
        timeout=timeout,
        end_on_error=end_on_error,
        logging_setting=logging_setting,
        log_file=log_file,
        subscriber=broadcast_callback,
        prompt_injection=prompt_injection,
        save_state=save_state,
        # consistency fix: ExecutorConfig.precedence_overwritten supports payload_callback,
        # but set_config previously offered no way to set it globally.
        payload_callback=payload_callback,
    )

    global_executor_config.set(new_config)
Sets the global configuration for the executor. This will be propagated to all new runners created after this call.
- If you call this function after the runner has been created, it will not affect the current runner.
- This function will only overwrite the values that are provided, leaving the rest unchanged.
def function_node(
    func: Callable[_P, Coroutine[None, None, _TOutput] | _TOutput]
    | List[Callable[_P, Coroutine[None, None, _TOutput] | _TOutput]],
    /,
    *,
    name: str | None = None,
    manifest: ToolManifest | None = None,
) -> (
    Callable[_P, Coroutine[None, None, _TOutput] | _TOutput]
    | List[Callable[_P, Coroutine[None, None, _TOutput] | _TOutput]]
):
    """
    Creates a new Node type from a function that can be used in `rt.call()`.

    By default, it will parse the function's docstring and turn it into tool details and parameters. However, if
    you provide a custom ToolManifest it will override that logic.

    WARNING: If you are overriding tool parameters, it is on you to make sure they will work with your function.

    NOTE: If you have already converted this function to a node this function will do nothing.

    Args:
        func (Callable): The function to convert into a Node.
        name (str, optional): Human-readable name for the node/tool.
        manifest (ToolManifest, optional): The details you would like to override the tool with.

    Returns:
        The same function (or list of functions) with a `node_type` attribute attached.

    Raises:
        NodeCreationError: If `func` is not a valid sync/async function, or node creation fails.
    """

    # handle the case where a list of functions is provided
    if isinstance(func, list):
        return [function_node(f, name=name, manifest=manifest) for f in func]

    # check if the function has already been converted to a node
    if hasattr(func, "node_type"):
        warnings.warn(
            "The provided function has already been converted to a node.",
            UserWarning,
        )
        return func

    # validate_function_parameters is separated out to allow for easier testing.
    validate_function_parameters(func, manifest)

    # assign the correct node class based on whether the function is async or sync
    if asyncio.iscoroutinefunction(func):
        node_class = AsyncDynamicFunctionNode
    elif inspect.isfunction(func):
        node_class = SyncDynamicFunctionNode
    elif inspect.isbuiltin(func):
        # builtin functions are written in C and do not have space for the addition of metadata like our node type.
        # so instead we wrap them in a function that allows for the addition of the node type.
        # this logic preserves details like the function name, docstring, and signature, but allows us to add the node type.
        func = _function_preserving_metadata(func)
        node_class = SyncDynamicFunctionNode
    else:
        raise NodeCreationError(
            message=f"The provided function is not a valid coroutine or sync function it is {type(func)}.",
            notes=[
                "You must provide a valid function or coroutine function to make a node.",
            ],
        )

    # build the node using the NodeBuilder
    builder = NodeBuilder(
        node_class,
        name=name if name is not None else f"{func.__name__}",
    )

    builder.setup_function_node(
        func,
        tool_details=manifest.description if manifest is not None else None,
        tool_params=manifest.parameters if manifest is not None else None,
    )

    completed_node_type = builder.build()

    # attach the built node type to the original function so it can be resolved later.
    if issubclass(completed_node_type, AsyncDynamicFunctionNode):
        setattr(func, "node_type", completed_node_type)
        return func
    elif issubclass(completed_node_type, SyncDynamicFunctionNode):
        setattr(func, "node_type", completed_node_type)
        return func
    else:
        raise NodeCreationError(
            message="The provided function did not create a valid node type.",
            notes=[
                "Please make a github issue with the details of what went wrong.",
            ],
        )
Creates a new Node type from a function that can be used in rt.call().
By default, it will parse the function's docstring and turn them into tool details and parameters. However, if you provide custom ToolManifest it will override that logic.
WARNING: If you are overriding tool parameters, it is on you to make sure they will work with your function.
NOTE: If you have already converted this function to a node this function will do nothing
Arguments:
- func (Callable): The function to convert into a Node.
- name (str, optional): Human-readable name for the node/tool.
- manifest (ToolManifest, optional): The details you would like to override the tool with.
def agent_node(
    name: str | None = None,
    *,
    rag: RagConfig | None = None,
    tool_nodes: Iterable[Type[Node] | Callable | RTFunction] | None = None,
    output_schema: Type[_TBaseModel] | None = None,
    llm: ModelBase[_TStream] | None = None,
    max_tool_calls: int | None = None,
    system_message: SystemMessage | str | None = None,
    manifest: ToolManifest | None = None,
):
    """
    Dynamically creates an agent based on the provided parameters.

    The concrete agent flavor is selected from the arguments:
    - tool_nodes and output_schema -> structured tool-call agent
    - tool_nodes only              -> tool-call agent
    - output_schema only           -> structured agent
    - neither                      -> terminal agent

    Args:
        name (str | None): The name of the agent. If none the default will be used.
        rag (RagConfig | None): If your agent is a rag agent put in the vector store it is connected to.
        tool_nodes (set[Type[Node] | Callable | RTFunction] | None): If your agent is a LLM with access to tools, what does it have access to?
        output_schema (Type[_TBaseModel] | None): If your agent should return a structured output, what is the output_schema?
        llm (ModelBase): The LLM model to use. If None it will need to be passed in at instance time.
        max_tool_calls (int | None): Maximum number of tool calls allowed (if it is a ToolCall Agent).
        system_message (SystemMessage | str | None): System message for the agent.
        manifest (ToolManifest | None): If you want to use this as a tool in other agents you can pass in a ToolManifest.

    Returns:
        The constructed agent node (with a RAG pre-invoke hook attached when `rag` is provided).
    """
    # normalize the tool inputs: plain functions are converted to their attached node types
    unpacked_tool_nodes: set[Type[Node]] | None = None
    if tool_nodes is not None:
        unpacked_tool_nodes = set()
        for node in tool_nodes:
            if isinstance(node, FunctionType):
                unpacked_tool_nodes.add(extract_node_from_function(node))
            else:
                assert issubclass(node, Node), (
                    f"Expected {node} to be a subclass of Node"
                )
                unpacked_tool_nodes.add(node)

    # See issue (___) this logic should be migrated soon.
    if manifest is not None:
        tool_details = manifest.description
        tool_params = manifest.parameters
    else:
        tool_details = None
        tool_params = None

    # dispatch on (has tools, has output schema) to pick the agent flavor
    if unpacked_tool_nodes is not None and len(unpacked_tool_nodes) > 0:
        if output_schema is not None:
            agent = structured_tool_call_llm(
                tool_nodes=unpacked_tool_nodes,
                output_schema=output_schema,
                name=name,
                llm=llm,
                max_tool_calls=max_tool_calls,
                system_message=system_message,
                tool_details=tool_details,
                tool_params=tool_params,
            )
        else:
            agent = tool_call_llm(
                tool_nodes=unpacked_tool_nodes,
                name=name,
                llm=llm,
                max_tool_calls=max_tool_calls,
                system_message=system_message,
                tool_details=tool_details,
                tool_params=tool_params,
            )
    else:
        if output_schema is not None:
            agent = structured_llm(
                output_schema=output_schema,
                name=name,
                llm=llm,
                system_message=system_message,
                tool_details=tool_details,
                tool_params=tool_params,
            )
        else:
            agent = terminal_llm(
                name=name,
                llm=llm,
                system_message=system_message,
                tool_details=tool_details,
                tool_params=tool_params,
            )

    if rag is not None:
        # before every invoke, augment the message history with context retrieved
        # from the configured vector store.
        def _update_message_history(node: LLMBase):
            node.message_hist = update_context(
                node.message_hist, vs=rag.vector_store, top_k=rag.top_k
            )
            return

        agent.add_pre_invoke(_update_message_history)

    return agent
Dynamically creates an agent based on the provided parameters.
Arguments:
- name (str | None): The name of the agent. If none the default will be used.
- rag (RagConfig | None): If your agent is a rag agent put in the vector store it is connected to.
- tool_nodes (set[Type[Node] | Callable | RTFunction] | None): If your agent is a LLM with access to tools, what does it have access to?
- output_schema (Type[_TBaseModel] | None): If your agent should return a structured output, what is the output_schema?
- llm (ModelBase): The LLM model to use. If None it will need to be passed in at instance time.
- max_tool_calls (int | None): Maximum number of tool calls allowed (if it is a ToolCall Agent).
- system_message (SystemMessage | str | None): System message for the agent.
- manifest (ToolManifest | None): If you want to use this as a tool in other agents you can pass in a ToolManifest.
class MCPStdioParams(StdioServerParameters):
    """
    Configuration parameters for STDIO-based MCP server connections.

    Extends the standard StdioServerParameters with a timeout field.

    Attributes:
        timeout: Maximum time to wait for operations (default: 30 seconds)
    """

    # Maximum time to wait for MCP operations before giving up.
    timeout: timedelta = timedelta(seconds=30)

    def as_stdio_params(self) -> StdioServerParameters:
        """
        Convert to standard StdioServerParameters, excluding the timeout field.

        Returns:
            StdioServerParameters without the timeout attribute
        """
        # NOTE(review): `.dict()` is the pydantic v1-style API; if this project is on
        # pydantic v2, `model_dump` is the non-deprecated equivalent — confirm before changing.
        stdio_kwargs = self.dict(exclude={"timeout"})
        return StdioServerParameters(**stdio_kwargs)
Configuration parameters for STDIO-based MCP server connections.
Extends the standard StdioServerParameters with a timeout field.
Attributes:
- timeout: Maximum time to wait for operations (default: 30 seconds)
    def as_stdio_params(self) -> StdioServerParameters:
        """
        Convert to standard StdioServerParameters, excluding the timeout field.

        The timeout field exists only on MCPStdioParams; the base class does
        not accept it, so it is excluded before re-constructing.

        Returns:
            StdioServerParameters without the timeout attribute
        """
        stdio_kwargs = self.dict(exclude={"timeout"})
        return StdioServerParameters(**stdio_kwargs)
Convert to standard StdioServerParameters, excluding the timeout field.
Returns:
StdioServerParameters without the timeout attribute
class MCPHttpParams(BaseModel):
    """
    Configuration parameters for HTTP-based MCP server connections.

    Supports both SSE (Server-Sent Events) and streamable HTTP transports.
    The transport type is automatically determined based on the URL.

    Attributes:
        url: The MCP server URL (use /sse suffix for SSE transport)
        headers: Optional HTTP headers for authentication
        timeout: Connection timeout (default: 30 seconds)
        sse_read_timeout: SSE read timeout (default: 5 minutes)
        terminate_on_close: Whether to terminate connection on close (default: True)
    """

    url: str
    headers: dict[str, Any] | None = None
    timeout: timedelta = timedelta(seconds=30)
    sse_read_timeout: timedelta = timedelta(seconds=60 * 5)
    terminate_on_close: bool = True
Configuration parameters for HTTP-based MCP server connections.
Supports both SSE (Server-Sent Events) and streamable HTTP transports. The transport type is automatically determined based on the URL.
Attributes:
- url: The MCP server URL (use /sse suffix for SSE transport)
- headers: Optional HTTP headers for authentication
- timeout: Connection timeout (default: 30 seconds)
- sse_read_timeout: SSE read timeout (default: 5 minutes)
- terminate_on_close: Whether to terminate connection on close (default: True)
def connect_mcp(
    config: MCPStdioParams | MCPHttpParams,
    client_session: ClientSession | None = None,
    setup_timeout: float = 30,
) -> MCPServer:
    """
    Connect to an MCP server and return a server instance with available tools.

    This is the primary entry point for using MCP servers in Railtracks.
    The server will connect in the background, discover available tools,
    and convert them to Railtracks Node classes.

    The connection remains active until explicitly closed or the context exits.

    Usage Examples:
        # STDIO connection (local MCP server)
        config = rt.MCPStdioParams(
            command="uvx",
            args=["mcp-server-time"]
        )
        server = rt.connect_mcp(config)

        # HTTP connection (remote MCP server)
        config = rt.MCPHttpParams(
            url="https://mcp.example.com/sse",
            headers={"Authorization": "Bearer token"}
        )
        server = rt.connect_mcp(config)

        # Context manager (recommended)
        with rt.connect_mcp(config) as server:
            tools = server.tools
            # Use tools...
        # Automatically closed

        # Access tools
        for tool in server.tools:
            print(f"Tool: {tool.name()}")
            print(f"Description: {tool.tool_info().description}")

    Args:
        config: Server configuration:
            - MCPStdioParams: For local servers via stdin/stdout
            - MCPHttpParams: For remote servers via HTTP/SSE
        client_session: Optional pre-configured ClientSession for advanced use cases.
            If not provided, a new session will be created automatically.
        setup_timeout: Maximum seconds to wait for connection (default: 30).
            Increase for slow servers or complex authentication flows.

    Returns:
        MCPServer: Connected server instance with:
            - tools: List of Node classes representing MCP tools
            - close(): Method to explicitly close the connection
            - Context manager support for automatic cleanup

    Raises:
        FileNotFoundError: If STDIO command not found. Verify the command is:
            - Installed and in your PATH
            - Spelled correctly (check for typos)
            - Executable (check permissions on Unix)
        ConnectionError: If connection to server fails. Check:
            - Server URL is correct and accessible
            - Network connectivity and firewall settings
            - Authentication credentials are valid
            - Server is running and accepting connections
        TimeoutError: If connection exceeds setup_timeout. Try:
            - Increasing setup_timeout parameter
            - Checking server performance/load
            - Verifying server is responding
        RuntimeError: For other setup failures (e.g., protocol errors, config issues)

    Note:
        - The connection runs in a background thread for sync/async bridging
        - Tools are cached after first retrieval for performance
        - Always close() the server when done or use context manager
        - Jupyter compatibility patches are applied automatically
    """
    # Apply Jupyter compatibility patches if needed
    apply_patches()

    # The MCPServer constructor kicks off the background connection; failures
    # surface within setup_timeout seconds (see Raises above).
    return MCPServer(
        config=config, client_session=client_session, setup_timeout=setup_timeout
    )
Connect to an MCP server and return a server instance with available tools.
This is the primary entry point for using MCP servers in Railtracks. The server will connect in the background, discover available tools, and convert them to Railtracks Node classes.
The connection remains active until explicitly closed or the context exits.
Usage Examples:
STDIO connection (local MCP server)
config = rt.MCPStdioParams( command="uvx", args=["mcp-server-time"] ) server = rt.connect_mcp(config)
HTTP connection (remote MCP server)
config = rt.MCPHttpParams( url="https://mcp.example.com/sse", headers={"Authorization": "Bearer token"} ) server = rt.connect_mcp(config)
Context manager (recommended)
with rt.connect_mcp(config) as server: tools = server.tools # Use tools...
Automatically closed
Access tools
for tool in server.tools: print(f"Tool: {tool.name()}") print(f"Description: {tool.tool_info().description}")
Arguments:
- config: Server configuration:
- MCPStdioParams: For local servers via stdin/stdout
- MCPHttpParams: For remote servers via HTTP/SSE
- client_session: Optional pre-configured ClientSession for advanced use cases. If not provided, a new session will be created automatically.
- setup_timeout: Maximum seconds to wait for connection (default: 30). Increase for slow servers or complex authentication flows.
Returns:
MCPServer: Connected server instance with: - tools: List of Node classes representing MCP tools - close(): Method to explicitly close the connection - Context manager support for automatic cleanup
Raises:
- FileNotFoundError: If STDIO command not found. Verify the command is:
- Installed and in your PATH
- Spelled correctly (check for typos)
- Executable (check permissions on Unix)
- ConnectionError: If connection to server fails. Check:
- Server URL is correct and accessible
- Network connectivity and firewall settings
- Authentication credentials are valid
- Server is running and accepting connections
- TimeoutError: If connection exceeds setup_timeout. Try:
- Increasing setup_timeout parameter
- Checking server performance/load
- Verifying server is responding
- RuntimeError: For other setup failures (e.g., protocol errors, config issues)
Note:
- The connection runs in a background thread for sync/async bridging
- Tools are cached after first retrieval for performance
- Always close() the server when done or use context manager
- Jupyter compatibility patches are applied automatically
def create_mcp_server(
    nodes: List[Node | RTFunction],
    server_name: str = "MCP Server",
    fastmcp: FastMCP | None = None,
):
    """
    Create a FastMCP server that can be used to run nodes as MCP tools.

    Args:
        nodes: List of Node classes to be registered as tools with the MCP server.
        server_name: Name of the MCP server instance.
        fastmcp: Optional FastMCP instance to use instead of creating a new one.

    Returns:
        A FastMCP server instance.

    Raises:
        ValueError: If fastmcp is provided but is not a FastMCP instance.
    """
    if fastmcp is not None:
        if not isinstance(fastmcp, FastMCP):
            raise ValueError("Provided fastmcp must be an instance of FastMCP.")
        mcp = fastmcp
    else:
        mcp = FastMCP(server_name)

    # Framework-wrapped functions (RTFunction) carry the real node class on
    # .node_type; unwrap those so the loop always operates on a Node class.
    for node in [n if not hasattr(n, "node_type") else n.node_type for n in nodes]:
        node_info = node.tool_info()
        func = _create_tool_function(node, node_info)

        # NOTE(review): this registers directly into FastMCP's private
        # _tool_manager._tools dict rather than through a public add_tool
        # API — fragile across FastMCP releases; confirm no public hook
        # exists before refactoring.
        mcp._tool_manager._tools[node_info.name] = MCPTool(
            fn=func,
            name=node_info.name,
            description=node_info.detail,
            parameters=(
                _parameters_to_json_schema(node_info.parameters)
                if node_info.parameters is not None
                else {}
            ),
            fn_metadata=func_metadata(func, []),
            is_async=True,
            context_kwarg=None,
            annotations=None,
        )  # Register the node as a tool

    return mcp
Create a FastMCP server that can be used to run nodes as MCP tools.
Arguments:
- nodes: List of Node classes to be registered as tools with the MCP server.
- server_name: Name of the MCP server instance.
- fastmcp: Optional FastMCP instance to use instead of creating a new one.
Returns:
A FastMCP server instance.
class ToolManifest:
    """
    Manifest describing a tool: a human-readable description plus the
    parameters the tool accepts.

    Args:
        description (str): A description of the tool.
        parameters (Iterable[Parameter] | None): An iterable of parameters for
            the tool. If None, the tool takes no parameters.
    """

    def __init__(
        self,
        description: str,
        parameters: Iterable[Parameter] | None = None,
    ):
        self.description = description
        # Normalize to a concrete list so the manifest owns its own copy and
        # callers may pass any iterable (tuple, generator, ...).
        self.parameters: List[Parameter] = [] if parameters is None else list(parameters)
Creates a manifest for a tool, which includes its description and parameters.
Arguments:
- description (str): A description of the tool.
- parameters (Iterable[Parameter] | None): An iterable of parameters for the tool. If None, there are no parameters.
def session_id():
    """
    Gets the current session ID if it exists, otherwise returns None.

    Returns:
        The active session's ID, or None when called outside a running
        session (get_session_id raises ContextError in that case).
    """
    try:
        return get_session_id()
    except ContextError:
        # No active session context: report "no session" rather than raising.
        return None
Gets the current session ID if it exists, otherwise returns None.
class RagConfig:
    """
    Configuration object for Retrieval-Augmented Generation (RAG).

    Args:
        vector_store (VectorStore): The vector store the RAG agent queries.
        top_k (int): How many top matches to retrieve per query (default 3).
    """

    def __init__(self, vector_store: VectorStore, top_k: int = 3) -> None:
        self.top_k = top_k
        self.vector_store = vector_store
Configuration object for Retrieval-Augmented Generation (RAG).
class Flow(Generic[_P, _TOutput]):
    """
    A reusable, pre-configured entry point into a Railtracks run.

    Bundles an entry-point node (or framework-wrapped function) together with
    session configuration so the same flow can be invoked repeatedly; each
    invocation opens a fresh Session built from this configuration, with a
    deep copy of the configured context.
    """

    def __init__(
        self,
        name: str,
        entry_point: (
            Callable[_P, Node[_TOutput]]
            | RTSyncFunction[_P, _TOutput]
            | RTAsyncFunction[_P, _TOutput]
        ),
        *,
        context: dict[str, Any] | None = None,
        timeout: float | None = None,
        end_on_error: bool | None = None,
        logging_setting: AllowableLogLevels | None = None,
        log_file: str | os.PathLike | None = None,
        broadcast_callback: (
            Callable[[str], None] | Callable[[str], Coroutine[None, None, None]] | None
        ) = None,
        prompt_injection: bool | None = None,
        save_state: bool | None = None,
        payload_callback: Callable[[dict[str, Any]], None] | None = None,
    ) -> None:
        """
        Configure a flow without running it.

        Args:
            name: Name of the flow; it is the only input to the flow's
                equality hash (see _get_hash_content).
            entry_point: Node class, or a framework-wrapped function exposing
                .node_type, that starts the flow.
            context: Initial context values; deep-copied for each run.
            timeout: Maximum seconds for a single invocation.
            end_on_error: Whether a node error terminates the whole run.
            logging_setting: Log level for the run.
            log_file: Optional log file path.
            broadcast_callback: Sync or async callback for broadcast messages.
            prompt_injection: Whether prompt injection is enabled.
            save_state: Whether to persist run state.
            payload_callback: Callback receiving raw payload dicts.
        """
        self.entry_point: Callable[_P, Node[_TOutput]]

        # Wrapped functions expose the underlying node class on .node_type;
        # unwrap so we always hold a node constructor.
        if hasattr(entry_point, "node_type"):
            self.entry_point = entry_point.node_type
        else:
            self.entry_point = entry_point

        self.name = name
        self._context: dict[str, Any] = context or {}
        self._timeout = timeout
        self._end_on_error = end_on_error
        self._logging_setting = logging_setting
        self._log_file = log_file
        self._broadcast_callback = broadcast_callback
        self._prompt_injection = prompt_injection
        self._save_state = save_state
        self._payload_callback = payload_callback

    def update_context(self, context: dict[str, Any]) -> Flow[_P, _TOutput]:
        """
        Creates a new Flow with the updated context. Note this will include the previous context values.
        """
        # Deep-copy so the original flow (and its context dict) is untouched.
        new_obj = deepcopy(self)
        new_obj._context.update(context)
        return new_obj

    async def ainvoke(self, *args: _P.args, **kwargs: _P.kwargs) -> _TOutput:
        """
        Run the flow asynchronously inside a fresh Session built from this
        flow's configuration, and return the entry point's result.
        """
        with Session(
            context=deepcopy(self._context),
            flow_name=self.name,
            flow_id=self.equality_hash(),
            name=None,
            timeout=self._timeout,
            end_on_error=self._end_on_error,
            logging_setting=self._logging_setting,
            log_file=self._log_file,
            broadcast_callback=self._broadcast_callback,
            prompt_injection=self._prompt_injection,
            save_state=self._save_state,
            payload_callback=self._payload_callback,
        ):
            result = await call(self.entry_point, *args, **kwargs)

        return result

    def invoke(self, *args: _P.args, **kwargs: _P.kwargs) -> _TOutput:
        """
        Run the flow synchronously. Must be called from outside any running
        event loop; from async code use `ainvoke` instead.

        Raises:
            RuntimeError: If called while an event loop is already running.
        """
        # Probe for a running loop *before* starting the run. The previous
        # implementation wrapped asyncio.run in `except RuntimeError`, which
        # also swallowed RuntimeErrors raised by the flow body itself and
        # replaced them with a misleading event-loop message. With this guard,
        # flow-raised exceptions propagate unchanged.
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No running loop: safe to drive the coroutine with asyncio.run.
            return asyncio.run(self.ainvoke(*args, **kwargs))

        raise RuntimeError(
            "Cannot invoke flow synchronously within an active event loop. Use 'ainvoke' instead."
        )

    def equality_hash(self) -> str:
        """
        Generates a hash based on the flow's configuration for equality checks.
        """
        config_string = json.dumps(self._get_hash_content(), sort_keys=True)
        return hashlib.sha256(config_string.encode()).hexdigest()

    def _get_hash_content(self) -> dict:
        # Only the name participates in identity; two flows with the same
        # name share a flow_id.
        return {
            "name": self.name,
        }
Abstract base class for generic types.
A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::
class Mapping(Generic[KT, VT]): def __getitem__(self, key: KT) -> VT: ... # Etc.
This class can then be used as follows::
def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT: try: return mapping[key] except KeyError: return default
    def __init__(
        self,
        name: str,
        entry_point: (
            Callable[_P, Node[_TOutput]]
            | RTSyncFunction[_P, _TOutput]
            | RTAsyncFunction[_P, _TOutput]
        ),
        *,
        context: dict[str, Any] | None = None,
        timeout: float | None = None,
        end_on_error: bool | None = None,
        logging_setting: AllowableLogLevels | None = None,
        log_file: str | os.PathLike | None = None,
        broadcast_callback: (
            Callable[[str], None] | Callable[[str], Coroutine[None, None, None]] | None
        ) = None,
        prompt_injection: bool | None = None,
        save_state: bool | None = None,
        payload_callback: Callable[[dict[str, Any]], None] | None = None,
    ) -> None:
        """
        Configure a flow without running it.

        Args:
            name: Name of the flow; it feeds the flow's equality hash.
            entry_point: Node class, or a framework-wrapped function exposing
                .node_type, that starts the flow.
            context: Initial context values for each run.
            timeout: Maximum seconds for a single invocation.
            end_on_error: Whether a node error terminates the run.
            logging_setting: Log level for the run.
            log_file: Optional log file path.
            broadcast_callback: Sync or async callback for broadcast messages.
            prompt_injection: Whether prompt injection is enabled.
            save_state: Whether to persist run state.
            payload_callback: Callback receiving raw payload dicts.
        """
        self.entry_point: Callable[_P, Node[_TOutput]]

        # Wrapped functions expose the underlying node class on .node_type;
        # unwrap so entry_point is always a node constructor.
        if hasattr(entry_point, "node_type"):
            self.entry_point = entry_point.node_type
        else:
            self.entry_point = entry_point

        self.name = name
        self._context: dict[str, Any] = context or {}
        self._timeout = timeout
        self._end_on_error = end_on_error
        self._logging_setting = logging_setting
        self._log_file = log_file
        self._broadcast_callback = broadcast_callback
        self._prompt_injection = prompt_injection
        self._save_state = save_state
        self._payload_callback = payload_callback
    def update_context(self, context: dict[str, Any]) -> Flow[_P, _TOutput]:
        """
        Creates a new Flow with the updated context. Note this will include the previous context values.

        The original flow is untouched: the whole Flow (including its context
        dict) is deep-copied, then `context` is merged on top, so keys present
        in both take the new value.

        Args:
            context: Key/value pairs to merge into the copied flow's context.

        Returns:
            A new Flow instance carrying the merged context.
        """
        new_obj = deepcopy(self)
        new_obj._context.update(context)
        return new_obj
Creates a new Flow with the updated context. Note this will include the previous context values.
    async def ainvoke(self, *args: _P.args, **kwargs: _P.kwargs) -> _TOutput:
        """
        Run the flow asynchronously.

        Opens a fresh Session built from this flow's configuration (the
        context is deep-copied per run, so runs cannot mutate each other's
        context), calls the entry point with the given arguments inside that
        session, and returns its result after the session closes.
        """
        with Session(
            context=deepcopy(self._context),
            flow_name=self.name,
            flow_id=self.equality_hash(),
            name=None,
            timeout=self._timeout,
            end_on_error=self._end_on_error,
            logging_setting=self._logging_setting,
            log_file=self._log_file,
            broadcast_callback=self._broadcast_callback,
            prompt_injection=self._prompt_injection,
            save_state=self._save_state,
            payload_callback=self._payload_callback,
        ):
            result = await call(self.entry_point, *args, **kwargs)

        return result
    def equality_hash(self) -> str:
        """
        Generates a hash based on the flow's configuration for equality checks.

        Returns:
            str: Hex SHA-256 digest of the JSON-serialized hash content
            (keys sorted for determinism; see _get_hash_content).
        """
        config_string = json.dumps(self._get_hash_content(), sort_keys=True)
        return hashlib.sha256(config_string.encode()).hexdigest()
Generates a hash based on the flow's configuration for equality checks.
def enable_logging(
    level: AllowableLogLevels = "INFO",
    log_file: str | os.PathLike | None = None,
) -> None:
    """
    Opt-in helper to enable Railtracks logging. Call this explicitly from your
    application entry point (CLI, main.py, server startup); the library never
    calls it automatically.

    Delegates to initialize_module_logging, which sets up console output (and
    an optional file handler).

    Args:
        level: Logging level (default "INFO").
        log_file: Optional path for a log file.

    Note:
        NOTE(review): the previous docstring said RT_LOG_LEVEL / RT_LOG_FILE
        environment variables are read "when None", but `level` defaults to
        "INFO" and its annotation does not admit None — confirm whether the
        env-var fallback lives inside initialize_module_logging before
        re-documenting it here.
    """
    initialize_module_logging(level=level, log_file=log_file)
Opt-in helper to enable Railtracks logging. Call this explicitly from your application entry point (CLI, main.py, server startup); the library never calls it automatically.
Uses the given level and log_file; when None, reads RT_LOG_LEVEL and RT_LOG_FILE from the environment. Sets up console output (and optional file) with a ThreadAwareFilter for per-thread level control.
Arguments:
- level: Logging level (default "INFO"). Overridden by RT_LOG_LEVEL when None.
- log_file: Optional path for a log file. Overridden by RT_LOG_FILE when None.