diff --git a/src/System.Management.Automation/engine/remoting/commands/InvokeCommandCommand.cs b/src/System.Management.Automation/engine/remoting/commands/InvokeCommandCommand.cs
index 0b7640c97c3..4850a064cce 100644
--- a/src/System.Management.Automation/engine/remoting/commands/InvokeCommandCommand.cs
+++ b/src/System.Management.Automation/engine/remoting/commands/InvokeCommandCommand.cs
@@ -223,6 +223,7 @@ public override PSCredential Credential
///
[Parameter(ParameterSetName = InvokeCommandCommand.ComputerNameParameterSet)]
[Parameter(ParameterSetName = InvokeCommandCommand.FilePathComputerNameParameterSet)]
+ [Parameter(ParameterSetName = InvokeCommandCommand.SSHHostParameterSet)]
[ValidateRange((Int32)1, (Int32)UInt16.MaxValue)]
public override Int32 Port
{
diff --git a/src/System.Management.Automation/engine/remoting/commands/PSRemotingCmdlet.cs b/src/System.Management.Automation/engine/remoting/commands/PSRemotingCmdlet.cs
index cf94a916d31..2a34eef8c22 100644
--- a/src/System.Management.Automation/engine/remoting/commands/PSRemotingCmdlet.cs
+++ b/src/System.Management.Automation/engine/remoting/commands/PSRemotingCmdlet.cs
@@ -275,6 +275,7 @@ internal struct SSHConnection
public string ComputerName;
public string UserName;
public string KeyFilePath;
+ public int Port;
}
///
@@ -561,6 +562,7 @@ public virtual PSCredential Credential
/// to override the policy setting
///
[Parameter(ParameterSetName = PSRemotingBaseCmdlet.ComputerNameParameterSet)]
+ [Parameter(ParameterSetName = PSRemotingBaseCmdlet.SSHHostParameterSet)]
[ValidateRange((Int32)1, (Int32)UInt16.MaxValue)]
public virtual Int32 Port { get; set; }
@@ -820,6 +822,7 @@ internal static void ValidateSpecifiedAuthentication(PSCredential credential, st
private const string UserNameParameter = "UserName";
private const string KeyFilePathParameter = "KeyFilePath";
private const string IdentityFilePathAlias = "IdentityFilePath";
+ private const string PortParameter = "Port";
#endregion
@@ -851,25 +854,23 @@ internal SSHConnection[] ParseSSHConnectionHashTable()
throw new PSArgumentException(RemotingErrorIdStrings.InvalidSSHConnectionParameter);
}
- string paramValue = item[paramName] as string;
- if (string.IsNullOrEmpty(paramValue))
- {
- throw new PSArgumentException(RemotingErrorIdStrings.InvalidSSHConnectionParameter);
- }
-
if (paramName.Equals(ComputerNameParameter, StringComparison.OrdinalIgnoreCase) || paramName.Equals(HostNameAlias, StringComparison.OrdinalIgnoreCase))
{
- var resolvedComputerName = ResolveComputerName(paramValue);
+ var resolvedComputerName = ResolveComputerName(GetSSHConnectionStringParameter(item[paramName]));
ValidateComputerName(new string[] { resolvedComputerName });
connectionInfo.ComputerName = resolvedComputerName;
}
else if (paramName.Equals(UserNameParameter, StringComparison.OrdinalIgnoreCase))
{
- connectionInfo.UserName = paramValue;
+ connectionInfo.UserName = GetSSHConnectionStringParameter(item[paramName]);
}
else if (paramName.Equals(KeyFilePathParameter, StringComparison.OrdinalIgnoreCase) || paramName.Equals(IdentityFilePathAlias, StringComparison.OrdinalIgnoreCase))
{
- connectionInfo.KeyFilePath = paramValue;
+ connectionInfo.KeyFilePath = GetSSHConnectionStringParameter(item[paramName]);
+ }
+ else if (paramName.Equals(PortParameter, StringComparison.OrdinalIgnoreCase))
+ {
+ connectionInfo.Port = GetSSHConnectionIntParameter(item[paramName]);
}
else
{
@@ -981,6 +982,37 @@ protected void ValidateComputerName(String[] computerNames)
}
}
+ ///
+ /// Validates parameter value and returns as string
+ ///
+ /// Parameter value to be validated
+ /// Parameter value as string
+ private static string GetSSHConnectionStringParameter(object param)
+ {
+ if (param is string paramValue && !string.IsNullOrEmpty(paramValue))
+ {
+ return paramValue;
+ }
+
+ throw new PSArgumentException(RemotingErrorIdStrings.InvalidSSHConnectionParameter);
+ }
+
+
+ ///
+ /// Validates parameter value and returns as integer
+ ///
+ /// Parameter value to be validated
+ /// Parameter value as integer
+ private static int GetSSHConnectionIntParameter(object param)
+ {
+ if (param is int paramValue)
+ {
+ return paramValue;
+ }
+
+ throw new PSArgumentException(RemotingErrorIdStrings.InvalidSSHConnectionParameter);
+ }
+
#endregion Private Methods
#region Overrides
@@ -1312,7 +1344,7 @@ protected void CreateHelpersForSpecifiedSSHComputerNames()
foreach (string computerName in ResolvedComputerNames)
{
- var sshConnectionInfo = new SSHConnectionInfo(this.UserName, computerName, this.KeyFilePath);
+ var sshConnectionInfo = new SSHConnectionInfo(this.UserName, computerName, this.KeyFilePath, this.Port);
var typeTable = TypeTable.LoadDefaultTypeFiles();
var remoteRunspace = RunspaceFactory.CreateRunspace(sshConnectionInfo, this.Host, typeTable) as RemoteRunspace;
var pipeline = CreatePipeline(remoteRunspace);
@@ -1334,7 +1366,8 @@ protected void CreateHelpersForSpecifiedSSHHashComputerNames()
var sshConnectionInfo = new SSHConnectionInfo(
sshConnection.UserName,
sshConnection.ComputerName,
- sshConnection.KeyFilePath);
+ sshConnection.KeyFilePath,
+ sshConnection.Port);
var typeTable = TypeTable.LoadDefaultTypeFiles();
var remoteRunspace = RunspaceFactory.CreateRunspace(sshConnectionInfo, this.Host, typeTable) as RemoteRunspace;
var pipeline = CreatePipeline(remoteRunspace);
diff --git a/src/System.Management.Automation/engine/remoting/commands/PushRunspaceCommand.cs b/src/System.Management.Automation/engine/remoting/commands/PushRunspaceCommand.cs
index 7b74e59ff8e..4703cd1c871 100644
--- a/src/System.Management.Automation/engine/remoting/commands/PushRunspaceCommand.cs
+++ b/src/System.Management.Automation/engine/remoting/commands/PushRunspaceCommand.cs
@@ -1270,7 +1270,7 @@ private RemoteRunspace GetRunspaceForContainerSession()
///
private RemoteRunspace GetRunspaceForSSHSession()
{
- var sshConnectionInfo = new SSHConnectionInfo(this.UserName, ResolveComputerName(HostName), this.KeyFilePath);
+ var sshConnectionInfo = new SSHConnectionInfo(this.UserName, ResolveComputerName(HostName), this.KeyFilePath, this.Port);
var typeTable = TypeTable.LoadDefaultTypeFiles();
var remoteRunspace = RunspaceFactory.CreateRunspace(sshConnectionInfo, this.Host, typeTable) as RemoteRunspace;
remoteRunspace.Open();
diff --git a/src/System.Management.Automation/engine/remoting/commands/newrunspacecommand.cs b/src/System.Management.Automation/engine/remoting/commands/newrunspacecommand.cs
index 12ca75b0a30..94164c5015f 100644
--- a/src/System.Management.Automation/engine/remoting/commands/newrunspacecommand.cs
+++ b/src/System.Management.Automation/engine/remoting/commands/newrunspacecommand.cs
@@ -1078,7 +1078,8 @@ private List CreateRunspacesForSSHHostParameterSet()
var sshConnectionInfo = new SSHConnectionInfo(
this.UserName,
computerName,
- this.KeyFilePath);
+ this.KeyFilePath,
+ this.Port);
var typeTable = TypeTable.LoadDefaultTypeFiles();
remoteRunspaces.Add(RunspaceFactory.CreateRunspace(sshConnectionInfo, this.Host, typeTable) as RemoteRunspace);
}
@@ -1095,7 +1096,8 @@ private List CreateRunspacesForSSHHostHashParameterSet()
var sshConnectionInfo = new SSHConnectionInfo(
sshConnection.UserName,
sshConnection.ComputerName,
- sshConnection.KeyFilePath);
+ sshConnection.KeyFilePath,
+ sshConnection.Port);
var typeTable = TypeTable.LoadDefaultTypeFiles();
remoteRunspaces.Add(RunspaceFactory.CreateRunspace(sshConnectionInfo, this.Host, typeTable) as RemoteRunspace);
}
diff --git a/src/System.Management.Automation/engine/remoting/common/RunspaceConnectionInfo.cs b/src/System.Management.Automation/engine/remoting/common/RunspaceConnectionInfo.cs
index b62337fadec..296565943cf 100644
--- a/src/System.Management.Automation/engine/remoting/common/RunspaceConnectionInfo.cs
+++ b/src/System.Management.Automation/engine/remoting/common/RunspaceConnectionInfo.cs
@@ -355,6 +355,36 @@ internal virtual RunspaceConnectionInfo InternalCopy()
throw new PSNotImplementedException();
}
+ ///
+ /// Validates port number is in range
+ ///
+ /// Port number to validate
+ internal virtual void ValidatePortInRange(int port)
+ {
+ if ((port < MinPort || port > MaxPort))
+ {
+ String message =
+ PSRemotingErrorInvariants.FormatResourceString(
+ RemotingErrorIdStrings.PortIsOutOfRange, port);
+ ArgumentException e = new ArgumentException(message);
+ throw e;
+ }
+ }
+
+ #endregion
+
+ #region Constants
+
+ ///
+ /// Maximum value for port
+ ///
+ protected const int MaxPort = 0xFFFF;
+
+ ///
+ /// Minimum value for port
+ ///
+ protected const int MinPort = 0;
+
#endregion
}
@@ -1172,6 +1202,7 @@ internal void ConstructUri(String scheme, String computerName, Nullable p
if (port.HasValue)
{
+ ValidatePortInRange(port.Value);
// resolve to default ports if required
if (port.Value == DefaultPort)
{
@@ -1185,14 +1216,6 @@ internal void ConstructUri(String scheme, String computerName, Nullable p
PortSetting = port.Value;
UseDefaultWSManPort = false;
}
- else if ((port.Value < MinPort || port.Value > MaxPort))
- {
- String message =
- PSRemotingErrorInvariants.FormatResourceString(
- RemotingErrorIdStrings.PortIsOutOfRange, port);
- ArgumentException e = new ArgumentException(message);
- throw e;
- }
else
{
PortSetting = port.Value;
@@ -1391,16 +1414,6 @@ internal bool UseDefaultWSManPort
///
private const string DefaultComputerName = "localhost";
- ///
- /// Maximum value for port
- ///
- private const int MaxPort = 0xFFFF;
-
- ///
- /// Minimum value for port
- ///
- private const int MinPort = 0;
-
///
/// String that represents the local host Uri
///
@@ -1852,6 +1865,15 @@ private string KeyFilePath
set;
}
+ ///
+ /// Port for connection
+ ///
+ private int Port
+ {
+ get;
+ set;
+ }
+
#endregion
#region Constructors
@@ -1878,6 +1900,25 @@ public SSHConnectionInfo(
this.UserName = userName;
this.ComputerName = computerName;
this.KeyFilePath = keyFilePath;
+ this.Port = DefaultPort;
+ }
+
+ ///
+ /// Constructor
+ ///
+ /// User Name
+ /// Computer Name
+ /// Key File Path
+ /// Port number for connection (default 22)
+ public SSHConnectionInfo(
+ string userName,
+ string computerName,
+ string keyFilePath,
+ int port) : this(userName, computerName, keyFilePath)
+ {
+ ValidatePortInRange(port);
+
+ this.Port = (port != 0) ? port : DefaultPort;
}
#endregion
@@ -1930,6 +1971,7 @@ internal override RunspaceConnectionInfo InternalCopy()
newCopy.ComputerName = this.ComputerName;
newCopy.UserName = this.UserName;
newCopy.KeyFilePath = this.KeyFilePath;
+ newCopy.Port = this.Port;
return newCopy;
}
@@ -2004,14 +2046,14 @@ internal System.Diagnostics.Process StartSSHProcess(
}
arguments = (string.IsNullOrEmpty(domainName)) ?
- string.Format(CultureInfo.InvariantCulture, @"-i ""{0}"" {1}@{2} -s powershell", this.KeyFilePath, userName, this.ComputerName) :
- string.Format(CultureInfo.InvariantCulture, @"-i ""{0}"" -l {1}@{2} {3} -s powershell", this.KeyFilePath, userName, domainName, this.ComputerName);
+ string.Format(CultureInfo.InvariantCulture, @"-i ""{0}"" {1}@{2} -p {3} -s powershell", this.KeyFilePath, userName, this.ComputerName, this.Port) :
+ string.Format(CultureInfo.InvariantCulture, @"-i ""{0}"" -l {1}@{2} {3} -p {4} -s powershell", this.KeyFilePath, userName, domainName, this.ComputerName, this.Port);
}
else
{
arguments = (string.IsNullOrEmpty(domainName)) ?
- string.Format(CultureInfo.InvariantCulture, @"{0}@{1} -s powershell", userName, this.ComputerName) :
- string.Format(CultureInfo.InvariantCulture, @"-l {0}@{1} {2} -s powershell", userName, domainName, this.ComputerName);
+ string.Format(CultureInfo.InvariantCulture, @"{0}@{1} -p {2} -s powershell", userName, this.ComputerName, this.Port) :
+ string.Format(CultureInfo.InvariantCulture, @"-l {0}@{1} {2} -p {3} -s powershell", userName, domainName, this.ComputerName, this.Port);
}
System.Diagnostics.ProcessStartInfo startInfo = new System.Diagnostics.ProcessStartInfo(
@@ -2037,7 +2079,16 @@ private string GetCurrentUserName()
#endif
}
-#endregion
+ #endregion
+
+ #region Constants
+
+ ///
+ /// Default value for port
+ ///
+ private const int DefaultPort = 22;
+
+ #endregion
#region SSH Process Creation
diff --git a/test/powershell/engine/Remoting/SSHRemotingAPI.Tests.ps1 b/test/powershell/engine/Remoting/SSHRemotingAPI.Tests.ps1
index 3dfbd866a8c..eeb71b4436f 100644
--- a/test/powershell/engine/Remoting/SSHRemotingAPI.Tests.ps1
+++ b/test/powershell/engine/Remoting/SSHRemotingAPI.Tests.ps1
@@ -1,41 +1,68 @@
+Import-Module $PSScriptRoot\..\..\Common\Test.Helpers.psm1
+
Describe "SSH Remoting API Tests" -Tags "Feature" {
Context "SSHConnectionInfo Class Tests" {
+ AfterEach {
+ if ($rs -ne $null) {
+ $rs.Dispose()
+ }
+ }
+
It "SSHConnectionInfo constructor should throw null argument exception for null HostName parameter" {
+ { [System.Management.Automation.Runspaces.SSHConnectionInfo]::new(
+ "UserName",
+ [System.Management.Automation.Internal.AutomationNull]::Value,
+ [System.Management.Automation.Internal.AutomationNull]::Value,
+ 0) } | ShouldBeErrorId "PSArgumentNullException"
+ }
+
+ It "SSHConnectionInfo should throw file not found exception for invalid key file path" {
+
try
{
- [System.Management.Automation.Runspaces.SSHConnectionInfo]::new(
+ $sshConnectionInfo = [System.Management.Automation.Runspaces.SSHConnectionInfo]::new(
"UserName",
- [System.Management.Automation.Internal.AutomationNull]::Value,
- [System.Management.Automation.Internal.AutomationNull]::Value)
+ "localhost",
+ "NoValidKeyFilePath",
+ 22)
+ $rs = [runspacefactory]::CreateRunspace($sshConnectionInfo)
+ $rs.Open()
+
throw "No Exception!"
}
catch
{
- $_.FullyQualifiedErrorId | Should Be "PSArgumentNullException"
+ $_.Exception.InnerException.InnerException | Should BeOfType "System.IO.FileNotFoundException"
}
}
- It "SSHConnectionInfo should throw file not found exception for invalid key file path" {
-
- try
+ It "SSHConnectionInfo should throw argument exception for invalid port (non 16bit uint)" {
+ try
{
$sshConnectionInfo = [System.Management.Automation.Runspaces.SSHConnectionInfo]::new(
"UserName",
"localhost",
- "NoValidKeyFilePath")
+ "ValidKeyFilePath",
+ 99999)
$rs = [runspacefactory]::CreateRunspace($sshConnectionInfo)
$rs.Open()
-
+
throw "No Exception!"
}
catch
{
- $_.Exception.InnerException.InnerException | Should BeOfType System.IO.FileNotFoundException
+ $expectedArgumentException = $_.Exception
+ if ($_.Exception.InnerException -ne $null)
+ {
+ $expectedArgumentException = $_.Exception.InnerException
+ }
+
+ $expectedArgumentException | Should BeOfType "System.ArgumentException"
}
}
}