summaryrefslogtreecommitdiffstats
path: root/utils/interfaces.py
blob: 7a0e3a9c555bb7acddb95e7b9ff4133f58f9792a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
# -------------------------------------------------------------------------
#   Copyright (c) 2015-2017 AT&T Intellectual Property
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
#
# -------------------------------------------------------------------------
#

import requests

from osdf.config.base import osdf_config, creds_prefixes
from osdf.logging.osdf_logging import MH, debug_log


def get_rest_client(request_json, service):
    """Build a RestClient that targets the callback URL carried by a request.

    Basic-auth credentials are read from ``osdf_config.deployment`` using the
    key prefix registered for ``service`` in ``creds_prefixes``: the keys are
    ``<prefix>Username`` and ``<prefix>Password``.

    :param request_json: parsed request body; must contain
        ``request_json["requestInfo"]["callbackUrl"]``
    :param service: service name used to look up the credential prefix
        (e.g. "so" or "cm")
    :return: RestClient aimed at the callback URL with the service credentials
    """
    callback = request_json["requestInfo"]["callbackUrl"]
    cred_prefix = creds_prefixes[service]
    deployment = osdf_config.deployment
    user_key = cred_prefix + "Username"
    pass_key = cred_prefix + "Password"
    return RestClient(url=callback,
                      userid=deployment[user_key],
                      passwd=deployment[pass_key])


class RestClient(object):
    """Simple REST client that supports GET/POST and HTTP basic auth.

    Per-client defaults (URL, method, headers, timeout, credentials) are set at
    construction time; each :meth:`request` call may override URL, method and
    timeout.
    """

    def __init__(self, userid=None, passwd=None, log_func=None, url=None, timeout=None, headers=None,
                 method="POST", req_id=None):
        """
        :param userid: basic-auth user name; auth is sent only if both userid and passwd are truthy
        :param passwd: basic-auth password
        :param log_func: optional callable; receives MH.received_http_response(res) for every response
        :param url: default endpoint used when request() is called without a url
        :param timeout: default timeout passed to requests; (30, 90) when None
        :param headers: default headers sent with every request (copied, never mutated)
        :param method: default HTTP method used when request() is called without one
        :param req_id: optional request ID included in debug log messages
        """
        self.auth = (userid, passwd) if userid and passwd else None
        # Copy so that add_headers() never mutates a dict owned by the caller.
        self.headers = dict(headers) if headers else {}
        self.method = method
        self.url = url
        self.log_func = log_func
        self.timeout = (30, 90) if timeout is None else timeout
        self.req_id = req_id

    def add_headers(self, headers):
        """Merge ``headers`` into this client's default headers (later values win)."""
        self.headers.update(headers)

    def request(self, url=None, method=None, asjson=True, ok_codes=(2, ),
                raw_response=False, noresponse=False, timeout=None, **kwargs):
        """Send an HTTP request and return the response in the requested form.

        :param url: REST end point to query (default is None => self.url)
        :param method: GET or POST (default is None => self.method)
        :param asjson: whether the expected response is in json format
        :param ok_codes: expected codes (prefix matching -- e.g. can be (20, 21, 32) or (2, 3))
        :param raw_response: If we need just the raw response (e.g. conductor sends transaction IDs in headers)
        :param noresponse: If no response is expected (as long as response codes are OK)
        :param timeout: Connection and read timeouts (default is None => self.timeout)
        :param kwargs: Other parameters, passed through to requests.request
        :return: the requests Response if raw_response; None if noresponse;
            the decoded JSON body if asjson; otherwise the raw body bytes
        :raises requests.HTTPError: if the status code does not start with any of ok_codes
        """
        target = url or self.url
        if self.req_id:
            debug_log.debug("Requesting URL: {} for request ID: {}".format(target, self.req_id))
        else:
            debug_log.debug("Requesting URL: {}".format(target))

        res = requests.request(url=target, method=method or self.method,
                               auth=self.auth, headers=self.headers,
                               timeout=timeout or self.timeout, **kwargs)

        if self.log_func:
            self.log_func(MH.received_http_response(res))

        res_code = str(res.status_code)
        if not res_code.startswith(tuple(str(code) for code in ok_codes)):
            # raise_for_status() only raises for 4xx/5xx responses and returns
            # None otherwise; any other rejected code (1xx, 3xx, or a 2xx that
            # ok_codes excludes) needs an explicit error instead of `raise None`.
            res.raise_for_status()
            raise requests.HTTPError(
                "Unexpected HTTP status {} from {} (accepted prefixes: {})".format(
                    res_code, target, ok_codes),
                response=res)

        if raw_response:
            return res
        if noresponse:
            return None
        if asjson:
            return res.json()
        return res.content