Skip to content

Commit e563945

Browse files
committed
Address feedback
1 parent 460e149 commit e563945

File tree

1 file changed

+46
-58
lines changed

1 file changed

+46
-58
lines changed

vm/src/stdlib/socket.rs

Lines changed: 46 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ impl PySocket {
113113
};
114114
self.kind.set(socket_kind);
115115
self.proto.set(proto);
116-
Socket::new(domain, socket_type, None).map_err(|err| convert_io_error(vm, err))?
116+
Socket::new(domain, socket_type, None).map_err(|err| convert_sock_error(vm, err))?
117117
};
118118
self.sock.replace(sock);
119119
Ok(())
@@ -127,15 +127,15 @@ impl PySocket {
127127
} else {
128128
self.sock().connect(&sock_addr)
129129
};
130-
res.map_err(|err| convert_io_error(vm, err))
130+
res.map_err(|err| convert_sock_error(vm, err))
131131
}
132132

133133
#[pymethod]
134134
fn bind(&self, address: Address, vm: &VirtualMachine) -> PyResult<()> {
135135
let sock_addr = get_addr(vm, address)?;
136136
self.sock()
137137
.bind(&sock_addr)
138-
.map_err(|err| convert_io_error(vm, err))
138+
.map_err(|err| convert_sock_error(vm, err))
139139
}
140140

141141
#[pymethod]
@@ -144,70 +144,51 @@ impl PySocket {
144144
let backlog = if backlog < 0 { 0 } else { backlog };
145145
self.sock()
146146
.listen(backlog)
147-
.map_err(|err| convert_io_error(vm, err))
147+
.map_err(|err| convert_sock_error(vm, err))
148148
}
149149

150150
#[pymethod]
151-
fn _accept(&self, vm: &VirtualMachine) -> PyResult {
151+
fn _accept(&self, vm: &VirtualMachine) -> PyResult<(RawSocket, AddrTuple)> {
152152
let (sock, addr) = self
153153
.sock()
154154
.accept()
155-
.map_err(|err| convert_io_error(vm, err))?;
155+
.map_err(|err| convert_sock_error(vm, err))?;
156156

157-
let fd = vm.new_int(into_sock_fileno(sock));
158-
let addr_tuple = get_addr_tuple(vm, addr);
159-
160-
Ok(vm.ctx.new_tuple(vec![fd, addr_tuple]))
157+
let fd = into_sock_fileno(sock);
158+
Ok((fd, get_addr_tuple(addr)))
161159
}
162160

163161
#[pymethod]
164162
fn recv(&self, bufsize: usize, vm: &VirtualMachine) -> PyResult {
165163
let mut buffer = vec![0u8; bufsize];
166164
match self.sock.borrow_mut().read_exact(&mut buffer) {
167165
Ok(()) => Ok(vm.ctx.new_bytes(buffer)),
168-
Err(ref e) if e.kind() == io::ErrorKind::TimedOut => {
169-
let socket_timeout = vm.class("_socket", "timeout");
170-
Err(vm.new_exception(socket_timeout, "Timed out".to_string()))
171-
}
172-
Err(err) => Err(convert_io_error(vm, err)),
166+
Err(err) => Err(convert_sock_error(vm, err)),
173167
}
174168
}
175169

176170
#[pymethod]
177-
fn recvfrom(&self, bufsize: usize, vm: &VirtualMachine) -> PyResult {
171+
fn recvfrom(&self, bufsize: usize, vm: &VirtualMachine) -> PyResult<(Vec<u8>, AddrTuple)> {
178172
let mut buffer = vec![0u8; bufsize];
179-
let addr = match self.sock().recv_from(&mut buffer) {
180-
Ok((_, addr)) => addr,
181-
Err(ref e) if e.kind() == io::ErrorKind::TimedOut => {
182-
let socket_timeout = vm.class("_socket", "timeout");
183-
return Err(vm.new_exception(socket_timeout, "Timed out".to_string()));
184-
}
185-
Err(err) => return Err(convert_io_error(vm, err)),
186-
};
187-
188-
let addr_tuple = get_addr_tuple(vm, addr);
189-
190-
Ok(vm.ctx.new_tuple(vec![vm.ctx.new_bytes(buffer), addr_tuple]))
173+
match self.sock().recv_from(&mut buffer) {
174+
Ok((_, addr)) => Ok((buffer, get_addr_tuple(addr))),
175+
Err(err) => Err(convert_sock_error(vm, err)),
176+
}
191177
}
192178

193179
#[pymethod]
194180
fn send(&self, bytes: PyBytesLike, vm: &VirtualMachine) -> PyResult<usize> {
195-
match self.sock().send(bytes.to_cow().as_ref()) {
196-
Ok(i) => Ok(i),
197-
Err(ref e) if e.kind() == io::ErrorKind::TimedOut => {
198-
let socket_timeout = vm.class("_socket", "timeout");
199-
Err(vm.new_exception(socket_timeout, "Timed out".to_string()))
200-
}
201-
Err(err) => Err(convert_io_error(vm, err)),
202-
}
181+
self.sock()
182+
.send(bytes.to_cow().as_ref())
183+
.map_err(|err| convert_sock_error(vm, err))
203184
}
204185

205186
#[pymethod]
206187
fn sendto(&self, bytes: PyBytesLike, address: Address, vm: &VirtualMachine) -> PyResult<()> {
207188
let addr = get_addr(vm, address)?;
208189
self.sock()
209190
.send_to(bytes.to_cow().as_ref(), &addr)
210-
.map_err(|err| convert_io_error(vm, err))?;
191+
.map_err(|err| convert_sock_error(vm, err))?;
211192
Ok(())
212193
}
213194

@@ -222,38 +203,38 @@ impl PySocket {
222203
}
223204

224205
#[pymethod]
225-
fn getsockname(&self, vm: &VirtualMachine) -> PyResult {
206+
fn getsockname(&self, vm: &VirtualMachine) -> PyResult<AddrTuple> {
226207
let addr = self
227208
.sock()
228209
.local_addr()
229-
.map_err(|err| convert_io_error(vm, err))?;
210+
.map_err(|err| convert_sock_error(vm, err))?;
230211

231-
Ok(get_addr_tuple(vm, addr))
212+
Ok(get_addr_tuple(addr))
232213
}
233214
#[pymethod]
234-
fn getpeername(&self, vm: &VirtualMachine) -> PyResult {
215+
fn getpeername(&self, vm: &VirtualMachine) -> PyResult<AddrTuple> {
235216
let addr = self
236217
.sock()
237218
.peer_addr()
238-
.map_err(|err| convert_io_error(vm, err))?;
219+
.map_err(|err| convert_sock_error(vm, err))?;
239220

240-
Ok(get_addr_tuple(vm, addr))
221+
Ok(get_addr_tuple(addr))
241222
}
242223

243224
#[pymethod]
244225
fn gettimeout(&self, vm: &VirtualMachine) -> PyResult<Option<f64>> {
245226
let dur = self
246227
.sock()
247228
.read_timeout()
248-
.map_err(|err| convert_io_error(vm, err))?;
229+
.map_err(|err| convert_sock_error(vm, err))?;
249230
Ok(dur.map(|d| d.as_secs_f64()))
250231
}
251232

252233
#[pymethod]
253234
fn setblocking(&self, block: bool, vm: &VirtualMachine) -> PyResult<()> {
254235
self.sock()
255236
.set_nonblocking(!block)
256-
.map_err(|err| convert_io_error(vm, err))
237+
.map_err(|err| convert_sock_error(vm, err))
257238
}
258239

259240
#[pymethod]
@@ -265,10 +246,10 @@ impl PySocket {
265246
fn settimeout(&self, timeout: Option<f64>, vm: &VirtualMachine) -> PyResult<()> {
266247
self.sock()
267248
.set_read_timeout(timeout.map(Duration::from_secs_f64))
268-
.map_err(|err| convert_io_error(vm, err))?;
249+
.map_err(|err| convert_sock_error(vm, err))?;
269250
self.sock()
270251
.set_write_timeout(timeout.map(Duration::from_secs_f64))
271-
.map_err(|err| convert_io_error(vm, err))?;
252+
.map_err(|err| convert_sock_error(vm, err))?;
272253
Ok(())
273254
}
274255

@@ -286,7 +267,7 @@ impl PySocket {
286267
};
287268
self.sock()
288269
.shutdown(how)
289-
.map_err(|err| convert_io_error(vm, err))
270+
.map_err(|err| convert_sock_error(vm, err))
290271
}
291272

292273
#[pyproperty(name = "type")]
@@ -329,19 +310,17 @@ impl TryFromObject for Address {
329310
}
330311
}
331312

332-
fn get_addr_tuple<A: Into<socket2::SockAddr>>(vm: &VirtualMachine, addr: A) -> PyObjectRef {
313+
type AddrTuple = (String, u16);
314+
315+
fn get_addr_tuple<A: Into<socket2::SockAddr>>(addr: A) -> AddrTuple {
333316
let addr = addr.into();
334-
let (port, ip) = if let Some(addr) = addr.as_inet() {
335-
(addr.port(), addr.ip().to_string())
317+
if let Some(addr) = addr.as_inet() {
318+
(addr.ip().to_string(), addr.port())
336319
} else if let Some(addr) = addr.as_inet6() {
337-
(addr.port(), addr.ip().to_string())
320+
(addr.ip().to_string(), addr.port())
338321
} else {
339-
(0, String::new())
340-
};
341-
let port = vm.ctx.new_int(port);
342-
let ip = vm.ctx.new_str(ip);
343-
344-
vm.ctx.new_tuple(vec![ip, port])
322+
(String::new(), 0)
323+
}
345324
}
346325

347326
fn socket_gethostname(vm: &VirtualMachine) -> PyResult {
@@ -438,6 +417,15 @@ fn invalid_sock() -> Socket {
438417
}
439418
}
440419

420+
fn convert_sock_error(vm: &VirtualMachine, err: io::Error) -> PyObjectRef {
421+
if err.kind() == io::ErrorKind::TimedOut {
422+
let socket_timeout = vm.class("_socket", "timeout");
423+
vm.new_exception(socket_timeout, "Timed out".to_string())
424+
} else {
425+
convert_io_error(vm, err)
426+
}
427+
}
428+
441429
pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
442430
let ctx = &vm.ctx;
443431
let socket_timeout = ctx.new_class("socket.timeout", vm.ctx.exceptions.os_error.clone());

0 commit comments

Comments
 (0)