diff --git a/README.md b/README.md index 12d4910f..df82def2 100644 --- a/README.md +++ b/README.md @@ -352,6 +352,7 @@ AnthropicBedrock( aws_secret_key='...', aws_access_key='...', aws_session_token='...', + api_key='...', # defaults to AWS_BEARER_TOKEN_BEDROCK envvar ) ``` diff --git a/src/anthropic/lib/bedrock/_auth.py b/src/anthropic/lib/bedrock/_auth.py index 0a8b2109..f0520543 100644 --- a/src/anthropic/lib/bedrock/_auth.py +++ b/src/anthropic/lib/bedrock/_auth.py @@ -41,7 +41,13 @@ def get_auth_headers( region: str | None, profile: str | None, data: str | None, + aws_bearer_token_bedrock: str | None = None, ) -> dict[str, str]: + if aws_bearer_token_bedrock is not None: + headers = headers.copy() + headers["Authorization"] = f"Bearer {aws_bearer_token_bedrock}" + return dict(headers) + from botocore.auth import SigV4Auth from botocore.awsrequest import AWSRequest diff --git a/src/anthropic/lib/bedrock/_client.py b/src/anthropic/lib/bedrock/_client.py index 013d2702..443a135e 100644 --- a/src/anthropic/lib/bedrock/_client.py +++ b/src/anthropic/lib/bedrock/_client.py @@ -156,6 +156,7 @@ def __init__( # outlining your use-case to help us decide if it should be # part of our public interface in the future. _strict_response_validation: bool = False, + api_key: str | None = None, ) -> None: self.aws_secret_key = aws_secret_key @@ -166,6 +167,11 @@ def __init__( self.aws_session_token = aws_session_token + if api_key is None: + api_key = os.environ.get("AWS_BEARER_TOKEN_BEDROCK") + + self.api_key = api_key + if base_url is None: base_url = os.environ.get("ANTHROPIC_BEDROCK_BASE_URL") if base_url is None: @@ -210,6 +216,7 @@ def _prepare_request(self, request: httpx.Request) -> None: region=self.aws_region or "us-east-1", profile=self.aws_profile, data=data, + aws_bearer_token_bedrock=self.api_key, ) request.headers.update(headers) @@ -298,6 +305,7 @@ def __init__( # outlining your use-case to help us decide if it should be # part of our public interface in the future. _strict_response_validation: bool = False, + api_key: str | None = None, ) -> None: self.aws_secret_key = aws_secret_key @@ -308,6 +316,11 @@ def __init__( self.aws_session_token = aws_session_token + if api_key is None: + api_key = os.environ.get("AWS_BEARER_TOKEN_BEDROCK") + + self.api_key = api_key + if base_url is None: base_url = os.environ.get("ANTHROPIC_BEDROCK_BASE_URL") if base_url is None: @@ -352,6 +365,7 @@ async def _prepare_request(self, request: httpx.Request) -> None: region=self.aws_region or "us-east-1", profile=self.aws_profile, data=data, + aws_bearer_token_bedrock=self.api_key, ) request.headers.update(headers) diff --git a/tests/lib/test_bedrock.py b/tests/lib/test_bedrock.py index fe62da43..210b499a 100644 --- a/tests/lib/test_bedrock.py +++ b/tests/lib/test_bedrock.py @@ -195,3 +195,51 @@ def test_region_infer_from_specified_profile( client = AnthropicBedrock() assert client.aws_region == next(profile for profile in profiles if profile["name"] == aws_profile)["region"] + + +@pytest.mark.respx() +def test_bearer_token_client_args(respx_mock: MockRouter) -> None: + respx_mock.post(re.compile(r"https://bedrock-runtime\.us-east-1\.amazonaws\.com/model/.*/invoke")).mock( + return_value=httpx.Response(200, json={"foo": "bar"}) + ) + + client = AnthropicBedrock(aws_region="us-east-1", api_key="test-bearer-token-from-args") + client.messages.create( + max_tokens=1024, + messages=[ + { + "role": "user", + "content": "Say hello there!", + } + ], + model="anthropic.claude-3-5-sonnet-20241022-v2:0", + ) + calls = cast("list[MockRequestCall]", respx_mock.calls) + assert len(calls) == 1 + auth_header = calls[0].request.headers.get("Authorization") + assert auth_header == "Bearer test-bearer-token-from-args" + + +@pytest.mark.respx() +def test_bearer_token_env(monkeypatch: pytest.MonkeyPatch, respx_mock: MockRouter) -> None: + respx_mock.post(re.compile(r"https://bedrock-runtime\.us-east-1\.amazonaws\.com/model/.*/invoke")).mock( + return_value=httpx.Response(200, json={"foo": "bar"}) + ) + + monkeypatch.setenv("AWS_BEARER_TOKEN_BEDROCK", "test-bearer-token-from-env") + + client = AnthropicBedrock(aws_region="us-east-1") + client.messages.create( + max_tokens=1024, + messages=[ + { + "role": "user", + "content": "Say hello there!", + } + ], + model="anthropic.claude-3-5-sonnet-20241022-v2:0", + ) + calls = cast("list[MockRequestCall]", respx_mock.calls) + assert len(calls) == 1 + auth_header = calls[0].request.headers.get("Authorization") + assert auth_header == "Bearer test-bearer-token-from-env"