1 /* Copyright 2019 The Chromium OS Authors. All rights reserved.
2  * Use of this source code is governed by a BSD-style license that can be
3  * found in the LICENSE file.
4  */
5 
6 #include <syslog.h>
7 
8 #include "cras_iodev_list.h"
9 #include "cras_messages.h"
10 #include "cras_observer.h"
11 #include "cras_rclient.h"
12 #include "cras_rclient_util.h"
13 #include "cras_rstream.h"
14 #include "cras_server_metrics.h"
15 #include "cras_system_state.h"
16 #include "cras_tm.h"
17 #include "cras_types.h"
18 #include "cras_util.h"
19 #include "stream_list.h"
20 
rclient_send_message_to_client(const struct cras_rclient * client,const struct cras_client_message * msg,int * fds,unsigned int num_fds)21 int rclient_send_message_to_client(const struct cras_rclient *client,
22 				   const struct cras_client_message *msg,
23 				   int *fds, unsigned int num_fds)
24 {
25 	return cras_send_with_fds(client->fd, (const void *)msg, msg->length,
26 				  fds, num_fds);
27 }
28 
rclient_destroy(struct cras_rclient * client)29 void rclient_destroy(struct cras_rclient *client)
30 {
31 	cras_observer_remove(client->observer);
32 	stream_list_rm_all_client_streams(cras_iodev_list_get_stream_list(),
33 					  client);
34 	free(client);
35 }
36 
rclient_validate_message_fds(const struct cras_server_message * msg,int * fds,unsigned int num_fds)37 int rclient_validate_message_fds(const struct cras_server_message *msg,
38 				 int *fds, unsigned int num_fds)
39 {
40 	switch (msg->id) {
41 	case CRAS_SERVER_CONNECT_STREAM:
42 		if (num_fds > 2)
43 			goto error;
44 		break;
45 	case CRAS_SERVER_SET_AEC_DUMP:
46 		if (num_fds > 1)
47 			goto error;
48 		break;
49 	default:
50 		if (num_fds > 0)
51 			goto error;
52 		break;
53 	}
54 
55 	return 0;
56 
57 error:
58 	syslog(LOG_ERR, "Message %d should not have %u fds attached.", msg->id,
59 	       num_fds);
60 	return -EINVAL;
61 }
62 
63 static int
rclient_validate_stream_connect_message(const struct cras_rclient * client,const struct cras_connect_message * msg)64 rclient_validate_stream_connect_message(const struct cras_rclient *client,
65 					const struct cras_connect_message *msg)
66 {
67 	if (!cras_valid_stream_id(msg->stream_id, client->id)) {
68 		syslog(LOG_ERR,
69 		       "stream_connect: invalid stream_id: %x for "
70 		       "client: %zx.\n",
71 		       msg->stream_id, client->id);
72 		return -EINVAL;
73 	}
74 
75 	int direction = cras_stream_direction_mask(msg->direction);
76 	if (direction < 0 || !(client->supported_directions & direction)) {
77 		syslog(LOG_ERR,
78 		       "stream_connect: invalid stream direction: %x for "
79 		       "client: %zx.\n",
80 		       msg->direction, client->id);
81 		return -EINVAL;
82 	}
83 
84 	if (!cras_validate_client_type(msg->client_type)) {
85 		syslog(LOG_ERR,
86 		       "stream_connect: invalid stream client_type: %x for "
87 		       "client: %zx.\n",
88 		       msg->client_type, client->id);
89 	}
90 	return 0;
91 }
92 
/* Validates the file descriptors that accompany a stream connect request.
 *
 * Rules:
 *   - audio_fd must be non-negative, otherwise -EBADF.
 *   - A non-zero client_shm_size requires a non-negative client_shm_fd,
 *     otherwise -EBADF.
 *   - A zero client_shm_size must not come with a non-negative
 *     client_shm_fd, otherwise -EINVAL.
 * Returns 0 when all rules hold. Each failure is logged.
 */
static int rclient_validate_stream_connect_fds(int audio_fd, int client_shm_fd,
					       size_t client_shm_size)
{
	const int wants_client_shm = client_shm_size > 0;
	const int has_client_shm_fd = client_shm_fd >= 0;

	/* The audio fd is always required. */
	if (audio_fd < 0) {
		syslog(LOG_ERR, "Invalid audio fd in stream connect.\n");
		return -EBADF;
	}

	/* Client shm was requested but no usable fd was supplied. */
	if (wants_client_shm && !has_client_shm_fd) {
		syslog(LOG_ERR,
		       "client_shm_fd must be valid if client_shm_size > 0.\n");
		return -EBADF;
	}

	/* A shm fd was supplied although client shm was not requested. */
	if (!wants_client_shm && has_client_shm_fd) {
		syslog(LOG_ERR,
		       "client_shm_fd can be valid only if client_shm_size > 0.\n");
		return -EINVAL;
	}

	return 0;
}
114 
rclient_validate_stream_connect_params(const struct cras_rclient * client,const struct cras_connect_message * msg,int audio_fd,int client_shm_fd)115 int rclient_validate_stream_connect_params(
116 	const struct cras_rclient *client,
117 	const struct cras_connect_message *msg, int audio_fd, int client_shm_fd)
118 {
119 	int rc;
120 
121 	rc = rclient_validate_stream_connect_message(client, msg);
122 	if (rc)
123 		return rc;
124 
125 	rc = rclient_validate_stream_connect_fds(audio_fd, client_shm_fd,
126 						 msg->client_shm_size);
127 	if (rc)
128 		return rc;
129 
130 	return 0;
131 }
132 
rclient_handle_client_stream_connect(struct cras_rclient * client,const struct cras_connect_message * msg,int aud_fd,int client_shm_fd)133 int rclient_handle_client_stream_connect(struct cras_rclient *client,
134 					 const struct cras_connect_message *msg,
135 					 int aud_fd, int client_shm_fd)
136 {
137 	struct cras_rstream *stream;
138 	struct cras_client_stream_connected stream_connected;
139 	struct cras_client_message *reply;
140 	struct cras_audio_format remote_fmt;
141 	struct cras_rstream_config stream_config;
142 	int rc, header_fd, samples_fd;
143 	size_t samples_size;
144 	int stream_fds[2];
145 
146 	rc = rclient_validate_stream_connect_params(client, msg, aud_fd,
147 						    client_shm_fd);
148 	remote_fmt = unpack_cras_audio_format(&msg->format);
149 	if (rc == 0 && !cras_audio_format_valid(&remote_fmt)) {
150 		rc = -EINVAL;
151 	}
152 	if (rc) {
153 		if (client_shm_fd >= 0)
154 			close(client_shm_fd);
155 		if (aud_fd >= 0)
156 			close(aud_fd);
157 		goto reply_err;
158 	}
159 
160 	/* When full, getting an error is preferable to blocking. */
161 	cras_make_fd_nonblocking(aud_fd);
162 
163 	stream_config = cras_rstream_config_init_with_message(
164 		client, msg, &aud_fd, &client_shm_fd, &remote_fmt);
165 	/* Overwrite client_type if client->client_type is set. */
166 	if (client->client_type != CRAS_CLIENT_TYPE_UNKNOWN)
167 		stream_config.client_type = client->client_type;
168 	rc = stream_list_add(cras_iodev_list_get_stream_list(), &stream_config,
169 			     &stream);
170 	if (rc)
171 		goto cleanup_config;
172 
173 	detect_rtc_stream_pair(cras_iodev_list_get_stream_list(), stream);
174 
175 	/* Tell client about the stream setup. */
176 	syslog(LOG_DEBUG, "Send connected for stream %x\n", msg->stream_id);
177 
178 	// Check that shm size is at most UINT32_MAX for non-shm streams.
179 	samples_size = cras_rstream_get_samples_shm_size(stream);
180 	if (samples_size > UINT32_MAX && stream_config.client_shm_fd < 0) {
181 		syslog(LOG_ERR,
182 		       "Non client-provided shm stream has samples shm larger "
183 		       "than uint32_t: %zu",
184 		       samples_size);
185 		if (aud_fd >= 0)
186 			close(aud_fd);
187 		rc = -EINVAL;
188 		goto cleanup_config;
189 	}
190 	cras_fill_client_stream_connected(&stream_connected, 0, /* No error. */
191 					  msg->stream_id, &remote_fmt,
192 					  samples_size,
193 					  cras_rstream_get_effects(stream));
194 	reply = &stream_connected.header;
195 
196 	rc = cras_rstream_get_shm_fds(stream, &header_fd, &samples_fd);
197 	if (rc)
198 		goto cleanup_config;
199 
200 	stream_fds[0] = header_fd;
201 	/* If we're using client-provided shm, samples_fd here refers to the
202 	 * same shm area as client_shm_fd */
203 	stream_fds[1] = samples_fd;
204 
205 	rc = client->ops->send_message_to_client(client, reply, stream_fds, 2);
206 	if (rc < 0) {
207 		syslog(LOG_ERR, "Failed to send connected messaged\n");
208 		stream_list_rm(cras_iodev_list_get_stream_list(),
209 			       stream->stream_id);
210 		goto cleanup_config;
211 	}
212 
213 	/* Cleanup local object explicitly. */
214 	cras_rstream_config_cleanup(&stream_config);
215 	return 0;
216 
217 cleanup_config:
218 	cras_rstream_config_cleanup(&stream_config);
219 
220 reply_err:
221 	/* Send the error code to the client. */
222 	cras_fill_client_stream_connected(&stream_connected, rc, msg->stream_id,
223 					  &remote_fmt, 0, msg->effects);
224 	reply = &stream_connected.header;
225 	client->ops->send_message_to_client(client, reply, NULL, 0);
226 
227 	return rc;
228 }
229 
230 /* Handles messages from the client requesting that a stream be removed from the
231  * server. */
rclient_handle_client_stream_disconnect(struct cras_rclient * client,const struct cras_disconnect_stream_message * msg)232 int rclient_handle_client_stream_disconnect(
233 	struct cras_rclient *client,
234 	const struct cras_disconnect_stream_message *msg)
235 {
236 	if (!cras_valid_stream_id(msg->stream_id, client->id)) {
237 		syslog(LOG_ERR,
238 		       "stream_disconnect: invalid stream_id: %x for "
239 		       "client: %zx.\n",
240 		       msg->stream_id, client->id);
241 		return -EINVAL;
242 	}
243 	return stream_list_rm(cras_iodev_list_get_stream_list(),
244 			      msg->stream_id);
245 }
246 
247 /* Creates a client structure and sends a message back informing the client that
248  * the connection has succeeded. */
rclient_generic_create(int fd,size_t id,const struct cras_rclient_ops * ops,int supported_directions)249 struct cras_rclient *rclient_generic_create(int fd, size_t id,
250 					    const struct cras_rclient_ops *ops,
251 					    int supported_directions)
252 {
253 	struct cras_rclient *client;
254 	struct cras_client_connected msg;
255 	int state_fd;
256 
257 	client = (struct cras_rclient *)calloc(1, sizeof(struct cras_rclient));
258 	if (!client)
259 		return NULL;
260 
261 	client->fd = fd;
262 	client->id = id;
263 	client->ops = ops;
264 	client->supported_directions = supported_directions;
265 
266 	cras_fill_client_connected(&msg, client->id);
267 	state_fd = cras_sys_state_shm_fd();
268 	client->ops->send_message_to_client(client, &msg.header, &state_fd, 1);
269 
270 	return client;
271 }
272 
273 /* A generic entry point for handling a message from the client. Called from
274  * the main server context. */
rclient_handle_message_from_client(struct cras_rclient * client,const struct cras_server_message * msg,int * fds,unsigned int num_fds)275 int rclient_handle_message_from_client(struct cras_rclient *client,
276 				       const struct cras_server_message *msg,
277 				       int *fds, unsigned int num_fds)
278 {
279 	int rc = 0;
280 	assert(client && msg);
281 
282 	rc = rclient_validate_message_fds(msg, fds, num_fds);
283 	if (rc < 0) {
284 		for (int i = 0; i < (int)num_fds; i++)
285 			if (fds[i] >= 0)
286 				close(fds[i]);
287 		return rc;
288 	}
289 	int fd = num_fds > 0 ? fds[0] : -1;
290 
291 	switch (msg->id) {
292 	case CRAS_SERVER_CONNECT_STREAM: {
293 		int client_shm_fd = num_fds > 1 ? fds[1] : -1;
294 		if (MSG_LEN_VALID(msg, struct cras_connect_message)) {
295 			rclient_handle_client_stream_connect(
296 				client,
297 				(const struct cras_connect_message *)msg, fd,
298 				client_shm_fd);
299 		} else {
300 			return -EINVAL;
301 		}
302 		break;
303 	}
304 	case CRAS_SERVER_DISCONNECT_STREAM:
305 		if (!MSG_LEN_VALID(msg, struct cras_disconnect_stream_message))
306 			return -EINVAL;
307 		rclient_handle_client_stream_disconnect(
308 			client,
309 			(const struct cras_disconnect_stream_message *)msg);
310 		break;
311 	default:
312 		break;
313 	}
314 
315 	return rc;
316 }
317