ansible/roles/ask/files/sanction-client.py

195 lines
7.1 KiB
Python

# vim: set ts=4 sw=)
""" OAuth 2.0 client librar
"""
from json import loads
from datetime import datetime, timedelta
from time import mktime
try:
from urllib import urlencode
from urllib2 import Request, urlopen
from urlparse import urlsplit, urlunsplit, parse_qsl
# monkeypatch httpmessage
from httplib import HTTPMessage
def get_charset(self):
try:
data = filter(lambda s: 'Content-Type' in s, self.headers)[0]
if 'charset' in data:
cs = data[data.index(';') + 1:-2].split('=')[1].lower()
return cs
except IndexError:
pass
return 'utf-8'
HTTPMessage.get_content_charset = get_charset
except ImportError:
from urllib.parse import urlencode, urlsplit, urlunsplit, parse_qsl
from urllib.request import Request, urlopen
class Client(object):
""" OAuth 2.0 client object
"""
def __init__(self, auth_endpoint=None, token_endpoint=None,
resource_endpoint=None, client_id=None, client_secret=None,
redirect_uri=None, token_transport=None):
assert(hasattr(token_transport, '__call__') or
token_transport in ('headers', 'query', None))
self.auth_endpoint = auth_endpoint
self.token_endpoint = token_endpoint
self.resource_endpoint = resource_endpoint
self.redirect_uri = redirect_uri
self.client_id = client_id
self.client_secret = client_secret
self.access_token = None
self.token_transport = token_transport or 'query'
self.token_expires = -1
self.refresh_token = None
def auth_uri(self, scope=None, scope_delim=None, state=None, **kwargs):
""" Builds the auth URI for the authorization endpoint
"""
scope_delim = scope_delim and scope_delim or ' '
kwargs.update({
'client_id': self.client_id,
'response_type': 'code',
})
if scope is not None:
kwargs['scope'] = scope_delim.join(scope)
if state is not None:
kwargs['state'] = state
if self.redirect_uri is not None:
kwargs['redirect_uri'] = self.redirect_uri
return '%s?%s' % (self.auth_endpoint, urlencode(kwargs))
def request_token(self, parser=None, exclude=None, **kwargs):
""" Request an access token from the token endpoint.
This is largely a helper method and expects the client code to
understand what the server expects. Anything that's passed into
``**kwargs`` will be sent (``urlencode``d) to the endpoint. Client
secret and client ID are automatically included, so are not required
as kwargs. For example::
# if requesting access token from auth flow:
{
'code': rval_from_auth,
}
# if refreshing access token:
{
'refresh_token': stored_refresh_token,
'grant_type': 'refresh_token',
}
:param exclude: An iterable of fields to exclude from the ``POST``
data. This is useful for fields such as ``redirect_uri``
that are required during initial code/token exchange,
but will cause errors with some providers when
exchanging refresh tokens for new access tokens.
:param parser: Callback to deal with returned data. Not all providers
use JSON.
"""
kwargs = kwargs and kwargs or {}
exclude = exclude or {}
parser = parser and parser or loads
kwargs.update({
'client_id': self.client_id,
'client_secret': self.client_secret,
'grant_type': 'grant_type' in kwargs and kwargs['grant_type'] or \
'authorization_code'
})
if self.redirect_uri is not None and 'redirect_uri' not in exclude:
kwargs.update({'redirect_uri': self.redirect_uri})
msg = urlopen(self.token_endpoint, urlencode(kwargs).encode(
'utf-8'))
data = parser(msg.read().decode(msg.info().get_content_charset() or
'utf-8'))
for key in data:
setattr(self, key, data[key])
# expires_in is RFC-compliant. if anything else is used by the
# provider, token_expires must be set manually
if hasattr(self, 'expires_in'):
self.token_expires = mktime((datetime.utcnow() + timedelta(
seconds=self.expires_in)).timetuple())
assert(self.access_token is not None)
def refresh(self):
assert self.refresh_token is not None
self.request_token(refresh_token=self.refresh_token,
grant_type='refresh_token', exclude=('redirect_uri',))
def request(self, url, method=None, data=None, parser=None):
""" Request user data from the resource endpoint
:param url: The path to the resource and querystring if required
:param method: HTTP method. Defaults to ``GET`` unless data is not None
in which case it defaults to ``POST``
:param data: Data to be POSTed to the resource endpoint
:param parser: Parser callback to deal with the returned data. Defaults
to ``json.loads`.`
"""
assert(self.access_token is not None)
parser = parser or loads
if not method:
method = 'GET' if not data else 'POST'
if not hasattr(self.token_transport, '__call__'):
transport = globals()['_transport_{0}'.format(self.token_transport)]
else:
transport = self.token_transport
req = transport('{0}{1}'.format(self.resource_endpoint,
url), self.access_token, data=data, method=method)
resp = urlopen(req)
data = resp.read()
try:
# try to decode it first using either the content charset, falling
# back to utf8
return parser(data.decode(resp.info().get_content_charset() or
'utf-8'))
except UnicodeDecodeError:
# if we've gotten a decoder error, the calling code better know how
# to deal with it. some providers (i.e. stackexchange) like to gzip
# their responses, so this allows the client code to handle it
# directly.
return parser(data)
def _transport_headers(url, access_token, data=None, method=None):
try:
req = Request(url, data=data, method=method)
except TypeError:
req = Request(url, data=data)
req.get_method = lambda: method
req.headers.update({
'Authorization': 'Bearer {0}'.format(access_token)
})
return req
def _transport_query(url, access_token, data=None, method=None):
parts = urlsplit(url)
query = dict(parse_qsl(parts.query))
query.update({
'access_token': access_token
})
url = urlunsplit((parts.scheme, parts.netloc, parts.path,
urlencode(query), parts.fragment))
try:
req = Request(url, data=data, method=method)
except TypeError:
req = Request(url, data=data)
req.get_method = lambda: method
return req