mirror of
https://github.com/hwchase17/langchain
synced 2024-11-18 09:25:54 +00:00
dc81dba6cf
Description: This pull request makes two key improvements to the langchain repository: **Fix for Crash in Flight Search Interface**: Previously, the code would crash when a request to the flight ticket search interface failed. This PR resolves the issue by handling such failures gracefully: the code now continues without crashing, ensuring smoother operation. **Documentation Update for Amadeus Toolkit**: Prior to this update, the examples in the Amadeus Toolkit documentation could not run correctly because they relied on outdated information. This PR updates the documentation so that all examples can now be executed successfully, giving users accurate, working examples of the Amadeus Toolkit. These changes aim to improve the reliability and usability of the langchain repository by strengthening error handling and keeping the documentation up to date and actionable. Issue: https://github.com/langchain-ai/langchain/issues/17375 Twitter Handle: SingletonYxx
154 lines
5.6 KiB
Python
154 lines
5.6 KiB
Python
import logging
|
|
from datetime import datetime as dt
|
|
from typing import Dict, Optional, Type
|
|
|
|
from langchain_core.callbacks import CallbackManagerForToolRun
|
|
from langchain_core.pydantic_v1 import BaseModel, Field
|
|
|
|
from langchain_community.tools.amadeus.base import AmadeusBaseTool
|
|
|
|
# Module-level logger used by the tool below to report invalid input and
# search failures instead of raising.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class FlightSearchSchema(BaseModel):
    """Input arguments for the AmadeusFlightSearch tool.

    Airports are identified by IATA location codes, and the departure window is
    given as two ``YYYY-MM-DDTHH:MM:SS`` datetime strings. ``page_number``
    selects which page of matching results to return.
    """

    # Required: IATA code of the departure airport.
    originLocationCode: str = Field(
        ...,
        description=(
            " The three letter International Air Transport "
            " Association (IATA) Location Identifier for the "
            " search's origin airport. "
        ),
    )

    # Required: IATA code of the arrival airport.
    destinationLocationCode: str = Field(
        ...,
        description=(
            " The three letter International Air Transport "
            " Association (IATA) Location Identifier for the "
            " search's destination airport. "
        ),
    )

    # Required: lower bound of the departure window (string, parsed by the tool).
    departureDateTimeEarliest: str = Field(
        ...,
        description=(
            " The earliest departure datetime from the origin airport "
            " for the flight search in the following format: "
            ' "YYYY-MM-DDTHH:MM:SS", where "T" separates the date and time '
            ' components. For example: "2023-06-09T10:30:00" represents '
            " June 9th, 2023, at 10:30 AM. "
        ),
    )

    # Required: upper bound of the departure window (string, parsed by the tool).
    departureDateTimeLatest: str = Field(
        ...,
        description=(
            " The latest departure datetime from the origin airport "
            " for the flight search in the following format: "
            ' "YYYY-MM-DDTHH:MM:SS", where "T" separates the date and time '
            ' components. For example: "2023-06-09T10:30:00" represents '
            " June 9th, 2023, at 10:30 AM. "
        ),
    )

    # Optional: which page of results to return; defaults to the first page.
    page_number: int = Field(
        1,
        description="The specific page number of flight results to retrieve",
    )
|
|
|
|
|
|
class AmadeusFlightSearch(AmadeusBaseTool):
    """Tool for searching for a single flight between two airports.

    Offers come from the Amadeus Flight Offers Search API for one adult on a
    single departure date. An offer is kept only when its first segment departs
    within the requested earliest/latest window.
    """

    name: str = "single_flight_search"
    description: str = (
        " Use this tool to search for a single flight between the origin and "
        " destination airports at a departure between an earliest and "
        " latest datetime. "
    )
    args_schema: Type[FlightSearchSchema] = FlightSearchSchema

    @staticmethod
    def _format_offer(offer: Dict, currencies: Dict, carriers: Dict) -> Dict:
        """Summarize one API flight offer into a ``price``/``segments`` dict.

        Currency and carrier codes are replaced by their names from the response
        dictionaries. A code missing from a dictionary is kept unchanged rather
        than raising ``KeyError``.
        """
        price = offer["price"]
        currency_code = price["currency"]

        # Only the first itinerary is summarized, as in a one-way search.
        segments = [
            {
                "departure": segment["departure"],
                "arrival": segment["arrival"],
                "flightNumber": segment["number"],
                "carrier": carriers.get(
                    segment["carrierCode"], segment["carrierCode"]
                ),
            }
            for segment in offer["itineraries"][0]["segments"]
        ]

        return {
            "price": {
                "total": price["total"],
                "currency": currencies.get(currency_code, currency_code),
            },
            "segments": segments,
        }

    def _run(
        self,
        originLocationCode: str,
        destinationLocationCode: str,
        departureDateTimeEarliest: str,
        departureDateTimeLatest: str,
        page_number: int = 1,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> list:
        """Search flight offers and return one page of summarized itineraries.

        Args:
            originLocationCode: IATA code of the origin airport.
            destinationLocationCode: IATA code of the destination airport.
            departureDateTimeEarliest: Earliest acceptable departure, formatted
                as ``"%Y-%m-%dT%H:%M:%S"``.
            departureDateTimeLatest: Latest acceptable departure, in the same
                format and on the same calendar date as the earliest bound.
            page_number: 1-based page of results; each page holds up to 10
                itineraries.
            run_manager: Callback manager supplied by the tool framework
                (not used here).

        Returns:
            A list of itinerary dicts. Each has ``"price"`` (``"total"``,
            ``"currency"``) and ``"segments"`` (``"departure"``, ``"arrival"``,
            ``"flightNumber"``, ``"carrier"``). Returns ``[None]`` when the two
            bounds fall on different dates. Returns an empty list when
            ``page_number`` is below 1 or the API call raises ``ResponseError``.

        Raises:
            ImportError: If the ``amadeus`` package is not installed.
            ValueError: If either datetime string does not match the format.
        """
        try:
            from amadeus import ResponseError
        except ImportError as e:
            raise ImportError(
                "Unable to import amadeus, please install with `pip install amadeus`."
            ) from e

        RESULTS_PER_PAGE = 10
        DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S"

        # Authenticate and retrieve a client
        client = self.client

        earliestDeparture = dt.strptime(departureDateTimeEarliest, DATETIME_FORMAT)
        latestDeparture = dt.strptime(departureDateTimeLatest, DATETIME_FORMAT)

        # The API is queried for a single departure date, so the window must
        # not span more than one day.
        if earliestDeparture.date() != latestDeparture.date():
            logger.error(
                " Error: Earliest and latest departure dates need to be the "
                " same date. If you're trying to search for round-trip "
                " flights, call this function for the outbound flight first, "
                " and then call again for the return flight. "
            )
            return [None]

        # Pages are 1-based. Smaller values would become negative slice
        # indices and return an arbitrary tail of the results.
        if page_number < 1:
            return []

        # Collect all results from the Amadeus Flight Offers Search API
        try:
            response = client.shopping.flight_offers_search.get(
                originLocationCode=originLocationCode,
                destinationLocationCode=destinationLocationCode,
                departureDate=latestDeparture.strftime("%Y-%m-%d"),
                adults=1,
            )
        except ResponseError as error:
            # A failed search yields no results instead of crashing the tool.
            logger.error("Amadeus flight offers search failed: %s", error)
            return []

        # Lookup tables that map currency/carrier codes to display names.
        dictionaries = (response.result or {}).get("dictionaries", {})
        currencies = dictionaries.get("currencies", {})
        carriers = dictionaries.get("carriers", {})

        # Keep only offers whose first segment departs inside the requested
        # window. Building a new list avoids the element-skipping caused by
        # popping from a list while iterating over it.
        output = []
        for offer in response.data or []:
            itinerary = self._format_offer(offer, currencies, carriers)
            offerDeparture = dt.strptime(
                itinerary["segments"][0]["departure"]["at"], DATETIME_FORMAT
            )
            if earliestDeparture <= offerDeparture <= latestDeparture:
                output.append(itinerary)

        # Return the paginated results
        startIndex = (page_number - 1) * RESULTS_PER_PAGE
        return output[startIndex : startIndex + RESULTS_PER_PAGE]
|