summaryrefslogtreecommitdiffstats
path: root/src/main/java/com/thinkaurelius/titan/diskstorage/cassandra/thrift/thriftpool/CTConnectionFactory.java
blob: 1c60cfdd3b0f559f313affccb313c00d2b1972f9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
package com.thinkaurelius.titan.diskstorage.cassandra.thrift.thriftpool;

import org.apache.cassandra.auth.IAuthenticator;
import org.apache.cassandra.thrift.*;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.pool.KeyedPoolableObjectFactory;
import org.apache.thrift.protocol.TBinaryProtocol;
import org.apache.thrift.transport.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.*;
import java.util.concurrent.atomic.AtomicReference;

/**
 * A factory compatible with Apache commons-pool for Cassandra Thrift API
 * connections.
 *
 * @author Dan LaRocque <dalaro@hopcount.org>
 */
public class CTConnectionFactory implements KeyedPoolableObjectFactory<String, CTConnection> {

    private static final Logger log = LoggerFactory.getLogger(CTConnectionFactory.class);

    // The Config this factory was built with. Only assigned in the constructor;
    // validateObject compares it against the Config each pooled connection carries.
    private final AtomicReference<Config> cfgRef;

    /**
     * Instances are created through {@link Config#build()}.
     */
    private CTConnectionFactory(Config config) {
        this.cfgRef = new AtomicReference<Config>(config);
    }

    /**
     * No-op: a pooled connection needs no preparation before being lent out.
     */
    @Override
    public void activateObject(String key, CTConnection c) throws Exception {
        // Do nothing, as in passivateObject
    }

    /**
     * Closes the connection's Thrift transport if it is still open.
     *
     * @param key the pool key (keyspace) the connection was created for
     * @param c   the connection being discarded by the pool
     */
    @Override
    public void destroyObject(String key, CTConnection c) throws Exception {
        TTransport t = c.getTransport();

        if (t.isOpen()) {
            t.close();
            log.trace("Closed transport {}", t);
        } else {
            log.trace("Not closing transport {} (already closed)", t);
        }
    }

    /**
     * Opens a new connection (see {@link #makeRawConnection()}) and selects
     * {@code key} as its keyspace.
     *
     * @param key the keyspace to bind the new connection to
     * @return an open connection with {@code key} set as its keyspace
     * @throws Exception if connecting, authenticating or selecting the keyspace
     *         fails; in the last case the transport is closed before the
     *         exception propagates
     */
    @Override
    public CTConnection makeObject(String key) throws Exception {
        CTConnection conn = makeRawConnection();
        try {
            conn.getClient().set_keyspace(key);
        } catch (Exception e) {
            // The pool never receives this connection, so nothing else would close it.
            closeQuietly(conn.getTransport());
            throw e;
        }
        return conn;
    }

    /**
     * Create a Cassandra-Thrift connection, but do not attempt to
     * set a keyspace on the connection.
     * <p>
     * A host is picked at random from the configured host list. If an SSL
     * truststore location is configured, the socket is obtained from
     * {@link TSSLTransportFactory}; otherwise a plain {@link TSocket} is used.
     * The socket is wrapped in a {@link TFramedTransport} using the configured
     * frame size. When a username is configured, a Thrift {@code login} is
     * performed; if that fails, the transport is closed and the failure is
     * rethrown wrapped in a {@link TTransportException}.
     *
     * @return A CTConnection ready to talk to a Cassandra cluster
     * @throws TTransportException on any Thrift transport failure, or on
     *         authentication/authorization failure during login
     */
    public CTConnection makeRawConnection() throws TTransportException {
        final Config cfg = cfgRef.get();

        String hostname = cfg.getRandomHost();

        // The password is intentionally not logged.
        log.debug("Creating TSocket({}, {}, {}, {})", hostname, cfg.port, cfg.username, cfg.timeoutMS);

        TSocket socket;
        if (cfg.sslTruststoreLocation != null && !cfg.sslTruststoreLocation.isEmpty()) {
            TSSLTransportFactory.TSSLTransportParameters params =
                    new TSSLTransportFactory.TSSLTransportParameters();
            params.setTrustStore(cfg.sslTruststoreLocation, cfg.sslTruststorePassword);
            socket = TSSLTransportFactory.getClientSocket(hostname, cfg.port, cfg.timeoutMS, params);
        } else {
            socket = new TSocket(hostname, cfg.port, cfg.timeoutMS);
        }

        TTransport transport = new TFramedTransport(socket, cfg.frameSize);
        log.trace("Created transport {}", transport);
        TBinaryProtocol protocol = new TBinaryProtocol(transport);
        Cassandra.Client client = new Cassandra.Client(protocol);
        // The SSL path may already hand back an open socket; only open when needed.
        if (!transport.isOpen()) {
            transport.open();
        }

        if (cfg.username != null) {
            Map<String, String> credentials = new HashMap<String, String>();
            credentials.put(IAuthenticator.USERNAME_KEY, cfg.username);
            credentials.put(IAuthenticator.PASSWORD_KEY, cfg.password);

            try {
                client.login(new AuthenticationRequest(credentials));
            } catch (Exception e) { // TTransportException will propagate authentication/authorization failure
                // Don't leak the opened socket when login is rejected.
                closeQuietly(transport);
                throw new TTransportException(e);
            }
        }
        return new CTConnection(transport, client, cfg);
    }

    /**
     * Closes {@code t} if it is open. Runtime failures while closing are logged,
     * not thrown, so that they do not mask the exception that triggered the cleanup.
     */
    private static void closeQuietly(TTransport t) {
        try {
            if (t.isOpen()) {
                t.close();
            }
        } catch (RuntimeException re) {
            log.warn("Failed to close transport {} after connection setup error", t, re);
        }
    }

    /**
     * No-op: a connection returned to the pool is kept as-is.
     */
    @Override
    public void passivateObject(String key, CTConnection o) throws Exception {
        // Do nothing, as in activateObject
    }

    /**
     * A pooled connection is valid when it was created with this factory's
     * current Config and is still open.
     * <p>
     * {@link Config} does not override {@code equals}, so the configuration
     * check is a reference-identity comparison.
     *
     * @param key the pool key (keyspace) of the connection
     * @param c   the connection to validate
     * @return {@code true} if {@code c} may be handed out again
     */
    @Override
    public boolean validateObject(String key, CTConnection c) {
        Config curCfg = cfgRef.get();

        boolean isSameConfig = c.getConfig().equals(curCfg);
        // Guard matches the level actually logged below.
        if (log.isTraceEnabled()) {
            if (isSameConfig) {
                log.trace("Validated {} by configuration {}", c, curCfg);
            } else {
                log.trace("Rejected {}; current config is {}; rejected connection config is {}",
                          c, curCfg, c.getConfig());
            }
        }

        return isSameConfig && c.isOpen();
    }

    /**
     * Connection settings, which also act as the builder for the factory.
     * <p>
     * Hosts, port and credentials are fixed at construction. The timeout,
     * frame size and SSL truststore settings may be changed only until
     * {@link #build()} is called; afterwards the setters throw
     * {@link IllegalStateException}.
     */
    public static class Config {

        private final String[] hostnames;
        private final int port;
        private final String username;
        private final String password;
        // Source of randomness for getRandomHost().
        private final Random random;

        private int timeoutMS;
        private int frameSize;

        private String sslTruststoreLocation;
        private String sslTruststorePassword;

        // Set by build(); once true, checkIfAlreadyBuilt() rejects modifications.
        private boolean isBuilt;

        /**
         * @param hostnames non-empty array of hosts to connect to; copied defensively
         * @param port      port the Thrift socket connects to on every host
         * @param username  login user, or {@code null} to skip the Thrift login
         * @param password  login password; only sent when {@code username} is non-null
         * @throws IllegalArgumentException if {@code hostnames} is null or empty
         */
        public Config(String[] hostnames, int port, String username, String password) {
            if (hostnames == null || hostnames.length == 0) {
                throw new IllegalArgumentException("At least one hostname is required");
            }
            // Copy so later changes to the caller's array cannot affect host selection.
            this.hostnames = hostnames.clone();
            this.port = port;
            this.username = username;
            this.password = password;
            this.random = new Random();
        }

        /**
         * @return the first configured host
         */
        public String getHostname() {
            return hostnames[0];
        }

        /**
         * @return the configured Thrift port
         */
        public int getPort() {
            return port;
        }

        /**
         * @return a host chosen uniformly at random from the configured hosts
         */
        public String getRandomHost() {
            return hostnames.length == 1 ? hostnames[0] : hostnames[random.nextInt(hostnames.length)];
        }

        /**
         * Sets the socket timeout in milliseconds.
         *
         * @throws IllegalStateException if {@link #build()} has already been called
         */
        public Config setTimeoutMS(int timeoutMS) {
            checkIfAlreadyBuilt();
            this.timeoutMS = timeoutMS;
            return this;
        }

        /**
         * Sets the maximum frame length passed to {@link TFramedTransport}.
         *
         * @throws IllegalStateException if {@link #build()} has already been called
         */
        public Config setFrameSize(int frameSize) {
            checkIfAlreadyBuilt();
            this.frameSize = frameSize;
            return this;
        }

        /**
         * Sets the SSL truststore path. A null or empty value disables SSL.
         *
         * @throws IllegalStateException if {@link #build()} has already been called
         */
        public Config setSSLTruststoreLocation(String location) {
            checkIfAlreadyBuilt();
            this.sslTruststoreLocation = location;
            return this;
        }

        /**
         * Sets the SSL truststore password.
         *
         * @throws IllegalStateException if {@link #build()} has already been called
         */
        public Config setSSLTruststorePassword(String password) {
            checkIfAlreadyBuilt();
            this.sslTruststorePassword = password;
            return this;
        }

        /**
         * Freezes this configuration and creates a factory that uses it.
         *
         * @return a new factory bound to this Config
         */
        public CTConnectionFactory build() {
            isBuilt = true;
            return new CTConnectionFactory(this);
        }


        /**
         * @throws IllegalStateException if {@link #build()} has already been called
         */
        public void checkIfAlreadyBuilt() {
            if (isBuilt)
                throw new IllegalStateException("Can't accept modifications when used with built factory.");
        }

        /**
         * Describes hosts, port, timeout and frame size. Credentials and SSL
         * settings are not included.
         */
        @Override
        public String toString() {
            return "Config[hostnames=" + StringUtils.join(hostnames, ',') + ", port=" + port
                    + ", timeoutMS=" + timeoutMS + ", frameSize=" + frameSize
                    + "]";
        }
    }

}