#!/usr/bin/python -u
# Copyright (C) 2014 Bashton Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import urllib2
import urlparse
import time
import hashlib
import hmac
import json
import sys
import os
from configobj import ConfigObj

import syslog


class AWSCredentials(object):
    """
    Obtain AWS credentials and use them to sign and open S3 requests.

    Credentials are read from the config file /etc/apt/s3auth.conf when it
    can be loaded; otherwise they are fetched from the instance meta-data
    server using the IAM role attached to the machine.
    """
    def __init__(self):
        host = 'http://169.254.169.254'
        path = '/latest/meta-data/iam/security-credentials/'
        self.meta_data_uri = urlparse.urljoin(host, path)

    def __get_role(self):
        """
        Read the IAM role name from the meta-data server into self.iamrole.

        Raises Exception when the meta-data server cannot be queried.
        """
        request = urllib2.Request(self.meta_data_uri)

        response = None
        try:
            response = urllib2.urlopen(request, None, 5)
            # The role name is joined onto a URL afterwards, so stray
            # surrounding whitespace must not survive.
            self.iamrole = response.read().strip()
        except urllib2.URLError as e:
            if hasattr(e, 'reason'):
                raise Exception("URL error reason: %s, probable cause is that"
                                " you don't have IAM role on this machine"
                                % e.reason)
            elif hasattr(e, 'code'):
                raise Exception("Server error code: %s" % e.code)
            # Never swallow the error: self.iamrole would be left unset.
            raise
        finally:
            if response:
                response.close()

    def __load_config(self):
        """
        Load the config file from its predefined location and return it.

        Example config file content:
            AccessKeyId = mykey
            SecretAccessKey = mysecretkey
            Token = '' # this can/have to be empty

        Raises Exception when the config file does not exist.
        """
        _CONF_FILE = '/etc/apt/s3auth.conf'
        conf_path = os.path.expanduser(_CONF_FILE)

        if not os.path.isfile(conf_path):
            # The caller treats this as "fall back to the IAM role", so it is
            # reported through the exception only (the syslog call that used
            # to follow the raise was unreachable).
            raise Exception("Config file: %s doesn't exist" % _CONF_FILE)
        return ConfigObj(conf_path)

    def get_credentials(self):
        """
        Populate self.access_key, self.secret_key and self.token.

        The config file is tried first; if it cannot be loaded the
        credentials are read from the AWS meta-data store for the IAM role.
        Note: This method should be explicitly called after constructing new
            object, as in 'explicit is better than implicit'.

        Raises Exception when the meta-data server cannot be queried.
        """
        data = None

        try:
            data = self.__load_config()
        except Exception:
            # Deliberate best effort: any failure to load the config file
            # means "use the IAM role". KeyboardInterrupt/SystemExit are no
            # longer swallowed here.
            data = None

        if data is None:
            self.__get_role()
            request = urllib2.Request(
                urlparse.urljoin(self.meta_data_uri, self.iamrole)
                )

            response = None

            try:
                response = urllib2.urlopen(request, None, 30)
                data = json.loads(response.read())
            except urllib2.URLError as e:
                if hasattr(e, 'reason'):
                    raise Exception("URL error reason: %s" % e.reason)
                elif hasattr(e, 'code'):
                    raise Exception("Server error code: %s" % e.code)
                # Never swallow the error: data would still be None below.
                raise
            finally:
                if response:
                    response.close()

        self.access_key = data['AccessKeyId']
        self.secret_key = data['SecretAccessKey']
        self.token = data['Token']

    @staticmethod
    def _bucket_from_host(host):
        """
        Return the bucket name of a 'bucket.s3...' style hostname.

        Raises Exception when the hostname does not contain '.s3'.
        """
        # TODO: bucket name finding is ugly, I should find a way to support
        # both naming conventions: http://bucket.s3.amazonaws.com/ and
        # http://s3.amazonaws.com/bucket/
        pos = host.find(".s3") if host else -1
        # An explicit check instead of assert: asserts are stripped under
        # "python -O", which would silently yield host[:-1] as the bucket.
        if pos == -1:
            raise Exception(
                "Can't establish bucket name based on the hostname: %s" % host)
        return host[:pos]

    def sign(self, request, timeval=None):
        """
        Compute the S3 signature for request and return it.

        A 'Date' header is added to request as a side effect; the signature
        itself is only returned, not attached.

        request - instance of Request
        timeval - optional time tuple used for the Date header, defaults to
                  the current GMT time

        Raises Exception when the bucket name cannot be derived from the
        request's hostname.
        """
        date = time.strftime("%a, %d %b %Y %H:%M:%S GMT",
                             timeval or time.gmtime())
        request.add_header('Date', date)

        bucket = self._bucket_from_host(request.get_host())

        resource = "/%s%s" % (bucket, request.get_selector())
        amz_headers = 'x-amz-security-token:%s\n' % self.token
        # Layout: method, two empty lines, date, amz headers, resource.
        sigstring = ("%(method)s\n\n\n%(date)s\n"
                     "%(canon_amzn_headers)s%(canon_amzn_resource)s") % ({
                         'method': request.get_method(),
                         'date': date,
                         'canon_amzn_headers': amz_headers,
                         'canon_amzn_resource': resource})
        digest = hmac.new(
            str(self.secret_key),
            str(sigstring),
            hashlib.sha1).digest()
        # The base64 codec appends a newline, hence the strip().
        return digest.encode('base64').strip()

    def urlopen(self, url, **kwargs):
        """
        urlopen(url) open the remote file and return a file object.

        HTTP error responses are returned (not raised) so the caller can
        inspect the status code. Other failures raise Exception. Extra
        keyword arguments are accepted but ignored.
        """
        try:
            return urllib2.urlopen(self._request(url), None, 30)
        except urllib2.HTTPError as e:
            # HTTPError is a "file like object" similar to what
            # urllib2.urlopen returns, so return it and let caller
            # deal with the error code
            return e
        # For other errors, throw an exception directly
        except urllib2.URLError as e:
            if hasattr(e, 'reason'):
                raise Exception("URL error reason: %s" % e.reason)
            elif hasattr(e, 'code'):
                raise Exception("Server error code: %s" % e.code)
            # Do not fall through and return None to the caller.
            raise
        except urllib2.socket.timeout:
            raise Exception("Socket timeout")

    def _request(self, url):
        """
        Build a Request for url carrying the security token and the
        'Authorization' header derived from sign().
        """
        request = urllib2.Request(url)
        request.add_header('x-amz-security-token', self.token)
        signature = self.sign(request)
        request.add_header(
            'Authorization', "AWS {0}:{1}".format(
                self.access_key,
                signature
            ).rstrip()
        )
        return request


class APTMessage(object):
    """
    One message of the line based protocol spoken between APT and a method.

    A message is a status line '<code> <text>' followed by 'Name: value'
    header lines and terminated by an empty line.

    code    - numeric message code, a key of MESSAGE_CODES when encoding
    headers - dict mapping header names to values
    """
    MESSAGE_CODES = {
        100: 'Capabilities',
        102: 'Status',
        200: 'URI Start',
        201: 'URI Done',
        400: 'URI Failure',
        600: 'URI Acquire',
        601: 'Configuration',
    }

    def __init__(self, code, headers):
        self.code = code
        self.headers = headers

    def process(self, lines):
        """
        Parse the lines of a message into this object and return it.

        lines - sequence of strings: the status line followed by header
                lines. Blank lines and lines without a ':' are ignored.
                The sequence is not modified.

        Raises ValueError when lines is empty or the status line does not
        start with an integer code.
        """
        if not lines:
            raise ValueError("Empty APT message")
        self.code = int(lines[0].split()[0])

        # Headers are kept as a dict because encode() expects one.
        headers = {}
        for line in lines[1:]:
            name, sep, value = line.strip().partition(':')
            if not sep:
                continue
            headers[name.strip()] = value.strip()
        self.headers = headers
        return self

    def encode(self):
        """
        Return the message as protocol text.

        Headers whose value is None are left out. Raises KeyError when
        self.code is not present in MESSAGE_CODES.
        """
        parts = ['{0} {1}\n'.format(self.code, self.MESSAGE_CODES[self.code])]
        for name, value in self.headers.items():
            if value is not None:
                parts.append('{0}: {1}\n'.format(name, value))
        # The empty line terminates the message.
        parts.append('\n')
        return ''.join(parts)


class S3_method(object):
    """
    APT method that downloads files from S3 using signed requests.

    Requests are read from stdin and answers are written to stdout using the
    text protocol APT speaks with its methods.
    """
    __eof = False
    # URI of the request being served. Defaults to None so that fail() can
    # be called before any URI is known (None headers are not sent).
    uri = None

    def __init__(self):
        self.iam = AWSCredentials()
        self.iam.get_credentials()
        self.send_capabilities()

    def fail(self, message='Failed'):
        """Report a failure for the current URI with the given message."""
        self.send_uri_failure({'URI': self.uri, 'Message': message})

    def _read_message(self):
        """
        Apt uses for communication with its methods the text protocol similar
        to http. This function parses the protocol messages from stdin.

        Returns a dict holding the headers plus '_number' (the integer code)
        and '_text' (the rest of the status line), or None at end of input.
        Header lines without a ':' are ignored.
        """
        if self.__eof:
            return None

        # Skip blank lines in front of the status line.
        line = sys.stdin.readline()
        while line == '\n':
            line = sys.stdin.readline()
        if not line:
            self.__eof = True
            return None

        # partition() tolerates a status line that has a code but no text.
        number, _, text = line.partition(' ')
        result = {'_number': int(number), '_text': text.strip()}

        while True:
            line = sys.stdin.readline()
            if not line:
                self.__eof = True
                return result
            if line == '\n':
                return result
            name, sep, value = line.partition(':')
            if not sep:
                # Malformed header line: skip it instead of crashing.
                continue
            result[name] = value.strip()

    def send(self, code, headers):
        """Write the message made of code and headers to stdout."""
        message = APTMessage(code, headers)
        sys.stdout.write(message.encode())
        # Do not depend on the interpreter being started unbuffered (-u).
        sys.stdout.flush()

    def send_capabilities(self):
        self.send(100, {'Version': '1.0', 'Single-Instance': 'true'})

    def send_status(self, headers):
        self.send(102, headers)

    def send_uri_start(self, headers):
        self.send(200, headers)

    def send_uri_done(self, headers):
        self.send(201, headers)

    def send_uri_failure(self, headers):
        self.send(400, headers)

    def run(self):
        """
        Loop through requests on stdin.

        Returns 0 at end of input and 100 as soon as a message other than
        600 is received. Errors raised while fetching are reported through
        fail() and do not end the loop.
        """
        while True:
            message = self._read_message()
            if message is None:
                return 0
            if message['_number'] == 600:
                try:
                    self.fetch(message)
                except Exception as e:
                    self.fail(e.__class__.__name__ + ": " + str(e))
            else:
                return 100

    # We need to be able to quote specific characters to support S3
    # lookups, something urllib and friends don't do easily
    def quote(self, s, unsafe):
        """Return s with every character found in unsafe %XX-escaped."""
        return ''.join(
            '%%%02X' % ord(c) if c in unsafe else c for c in s)

    def fetch(self, msg):
        """
        Download msg['URI'] into msg['Filename'] and report the outcome.

        Sends a status message, then either a URI failure (response code
        other than 200) or URI start / URI done messages carrying the size,
        modification time and hashes of the downloaded data.
        """
        # Reset first so that a request without a URI header does not make
        # fail() report the URI of the previous request.
        self.uri = None
        self.uri = msg['URI']
        self.uri_parsed = urlparse.urlparse(self.uri)
        # quote path for +, ~, and spaces
        # see bugs.launchpad.net #1003633 and #1086997
        self.uri_updated = 'https://' + self.uri_parsed.netloc +\
            self.quote(self.uri_parsed.path, '+~ ')
        self.filename = msg['Filename']

        response = self.iam.urlopen(self.uri_updated)
        # try/finally: the response is closed even when writing fails.
        try:
            self.send_status(
                {'URI': self.uri, 'Message': 'Waiting for headers'})

            if response.code != 200:
                self.send_uri_failure({
                    'URI': self.uri,
                    'Message': str(response.code) + '  ' + response.msg,
                    'FailReason': 'HttpError' + str(response.code)})
                # Read and discard the error body.
                while response.read(4096):
                    pass
                return

            size = response.headers.getheader('content-length')
            last_modified = response.headers.getheader('last-modified')
            self.send_uri_start({
                'URI': self.uri,
                'Size': size,
                'Last-Modified': last_modified})

            hash_sha256 = hashlib.sha256()
            hash_sha512 = hashlib.sha512()
            hash_md5 = hashlib.md5()
            # Binary mode: the payload must be written byte for byte.
            with open(self.filename, "wb") as f:
                while True:
                    data = response.read(4096)
                    if not data:
                        break
                    hash_sha256.update(data)
                    hash_sha512.update(data)
                    hash_md5.update(data)
                    f.write(data)
        finally:
            response.close()

        self.send_uri_done({
            'URI': self.uri,
            'Filename': self.filename,
            'Size': size,
            'Last-Modified': last_modified,
            'MD5-Hash': hash_md5.hexdigest(),
            'MD5Sum-Hash': hash_md5.hexdigest(),
            'SHA256-Hash': hash_sha256.hexdigest(),
            'SHA512-Hash': hash_sha512.hexdigest()})

if __name__ == '__main__':
    # Script entry point: build the method (which loads the credentials and
    # announces the capabilities), serve requests from stdin and use the
    # value returned by run() as the process exit status.
    try:
        sys.exit(S3_method().run())
    except KeyboardInterrupt:
        # Interrupted by the user: leave quietly without a traceback.
        pass
