Browse Source

make sure a bad api key is unauthorized

main
John-Mark Gurney 4 years ago
parent
commit
7260db1d6b
1 changed files with 7 additions and 3 deletions
  1. +7
    -3
      bitelab/__init__.py

+ 7
- 3
bitelab/__init__.py View File

@@ -1,7 +1,7 @@
from typing import Optional from typing import Optional
from functools import lru_cache, wraps from functools import lru_cache, wraps


from fastapi import APIRouter, Depends, FastAPI, Request
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from httpx import AsyncClient, Auth from httpx import AsyncClient, Auth
from starlette.status import HTTP_200_OK, HTTP_404_NOT_FOUND, HTTP_401_UNAUTHORIZED from starlette.status import HTTP_200_OK, HTTP_404_NOT_FOUND, HTTP_401_UNAUTHORIZED
@@ -11,6 +11,7 @@ from . import data


import asyncio import asyncio
import gc import gc
import orm
import socket import socket
import sys import sys
import tempfile import tempfile
@@ -103,9 +104,9 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl='/nonexistent')
async def lookup_user(token: str = Depends(oauth2_scheme), data: data.DataWrapper = Depends(get_data)): async def lookup_user(token: str = Depends(oauth2_scheme), data: data.DataWrapper = Depends(get_data)):
try: try:
return (await data.APIKey.objects.get(key=token)).user return (await data.APIKey.objects.get(key=token)).user
except KeyError:
except orm.exceptions.NoMatch:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
status_code=HTTP_401_UNAUTHORIZED,
detail='Invalid authentication credentials', detail='Invalid authentication credentials',
headers={'WWW-Authenticate': 'Bearer'}, headers={'WWW-Authenticate': 'Bearer'},
) )
@@ -213,6 +214,9 @@ class TestBiteLab(unittest.IsolatedAsyncioTestCase):
res = await self.client.get('/board_info') res = await self.client.get('/board_info')
self.assertEqual(res.status_code, HTTP_401_UNAUTHORIZED) self.assertEqual(res.status_code, HTTP_401_UNAUTHORIZED)


res = await self.client.get('/board_classes', auth=BiteAuth('badapikey'))
self.assertEqual(res.status_code, HTTP_401_UNAUTHORIZED)

async def test_classes(self): async def test_classes(self):
res = await self.client.get('/board_classes', auth=BiteAuth('thisisanapikey')) res = await self.client.get('/board_classes', auth=BiteAuth('thisisanapikey'))
self.assertEqual(res.status_code, HTTP_200_OK) self.assertEqual(res.status_code, HTTP_200_OK)


Loading…
Cancel
Save